import { useState, useEffect, useRef, useCallback } from "react"; import { Streamdown } from "streamdown"; import { code } from "@streamdown/code"; import { mermaid } from "@streamdown/mermaid"; import { createMathPlugin } from "@streamdown/math"; import { cjk } from "@streamdown/cjk"; import "streamdown/styles.css"; import "katex/dist/katex.min.css"; const MODELS = [ { id: "1.7b", name: "Bonsai 1.7B", params: "1.7B", size: "290 MB", blurb: "Pocket-class. Built for wearables and always-on agents.", }, { id: "4b", name: "Bonsai 4B", params: "4B", size: "584 MB", blurb: "The sweet spot. Strong reasoning at on-device latency.", comingSoon: true, }, { id: "8b", name: "Bonsai 8B", params: "8B", size: "1.2 GB", blurb: "Datacenter-grade reasoning, in your browser tab.", comingSoon: true, }, ]; const formatBytes = (bytes) => { if (!bytes) return "0 MB"; if (bytes >= 1e9) return `${(bytes / 1e9).toFixed(2)} GB`; return `${(bytes / 1e6).toFixed(0)} MB`; }; const math = createMathPlugin({ singleDollarTextMath: true }); const STREAMDOWN_PLUGINS = { code, mermaid, math, cjk }; const PRISM_GLYPH_CLASS = "h-9 w-9 overflow-hidden opacity-90 [clip-path:polygon(50%_4%,100%_100%,0%_100%)] bg-[radial-gradient(circle_at_50%_18%,rgba(255,255,255,0.3),transparent_28%),linear-gradient(180deg,var(--prism-2)_0%,var(--prism-1)_42%,var(--prism-6)_100%)] drop-shadow-[0_0_18px_rgba(255,184,77,0.18)]"; function SendIcon() { return ( ); } function ResetIcon() { return ( ); } function StopIcon() { return ( ); } export default function PrismDemo() { const canvasRef = useRef(null); const mouseRef = useRef({ x: -9999, y: -9999, inside: false }); const cellsRef = useRef([]); const wavesRef = useRef([]); const rafRef = useRef(null); const playgroundRef = useRef(null); const inputRef = useRef(null); const messagesRef = useRef(null); const autoScrollRef = useRef(true); const optimizeMessageTimeoutRef = useRef(null); // Playground stage machine const [stage, setStage] = useState("select"); // 'select' | 'loading' | 'chat' | 'error' const [selectedId, setSelectedId] = useState("1.7b"); const [loadProgress, setLoadProgress] = useState(0); const [loadMessage, setLoadMessage] = useState("Fetching weights"); const [loadLoaded, setLoadLoaded] = useState(0); const [loadTotal, setLoadTotal] = useState(0); const [errorMessage, setErrorMessage] = useState(null); const messagesHistoryRef = useRef([]); // Chat state const [messages, setMessages] = useState([]); const [input, setInput] = useState(""); const [isThinking, setIsThinking] = useState(false); const [isStreaming, setIsStreaming] = useState(false); const [tps, setTps] = useState(null); const workerRef = useRef(null); const [fontsReady, setFontsReady] = useState(false); useEffect(() => { if (!document.fonts?.load) { setFontsReady(true); return; } Promise.all([ document.fonts.load('400 132px "Instrument Serif"'), document.fonts.load('italic 400 132px "Instrument Serif"'), ]).finally(() => setFontsReady(true)); }, []); const selected = MODELS.find((m) => m.id === selectedId) || MODELS[0]; const scrollToPlayground = useCallback(() => { playgroundRef.current?.scrollIntoView({ behavior: "smooth", block: "start", }); }, []); useEffect(() => { if (!workerRef.current) { workerRef.current = new Worker(new URL("./worker.js", import.meta.url), { type: "module", }); } const worker = workerRef.current; const onMessage = (e) => { const d = e.data; switch (d.status) { case "progress_total": setLoadProgress(d.progress); setLoadLoaded(d.loaded); setLoadTotal(d.total); if (Number(d.progress) >= 100) { clearTimeout(optimizeMessageTimeoutRef.current); optimizeMessageTimeoutRef.current = setTimeout(() => { setLoadMessage("Optimizing model for 1-bit execution"); }, 100); } break; case "loading": setLoadMessage(d.data); break; case "ready": clearTimeout(optimizeMessageTimeoutRef.current); setStage("chat"); setTimeout(() => inputRef.current?.focus(), 200); break; case "start": setIsThinking(false); setIsStreaming(true); setTps(null); setMessages((m) => [...m, { role: "assistant", content: "" }]); break; case "update": if (d.tps != null) setTps(d.tps); setMessages((m) => { const copy = [...m]; const last = copy[copy.length - 1]; copy[copy.length - 1] = { ...last, content: last.content + d.output, }; return copy; }); break; case "complete": setIsStreaming(false); break; case "error": clearTimeout(optimizeMessageTimeoutRef.current); setErrorMessage(d.data); setStage("error"); setIsThinking(false); setIsStreaming(false); break; } }; worker.addEventListener("message", onMessage); return () => { clearTimeout(optimizeMessageTimeoutRef.current); worker.removeEventListener("message", onMessage); }; }, []); const startLoading = () => { clearTimeout(optimizeMessageTimeoutRef.current); setStage("loading"); setLoadProgress(0); setLoadMessage("Fetching weights"); setLoadLoaded(0); setLoadTotal(0); workerRef.current?.postMessage({ type: "load", data: selectedId }); }; // --- Background canvas: digits + glistening waves --- useEffect(() => { const canvas = canvasRef.current; if (!canvas) return; const ctx = canvas.getContext("2d"); const dpr = window.devicePixelRatio || 1; const CELL_W = 16; const CELL_H = 22; const PRISM = [ [255, 122, 92], [255, 184, 77], [196, 217, 46], [77, 208, 196], [124, 142, 232], [182, 123, 232], ]; const setupCells = () => { const w = canvas.clientWidth; const h = canvas.clientHeight; canvas.width = w * dpr; canvas.height = h * dpr; ctx.scale(dpr, dpr); const cols = Math.ceil(w / CELL_W); const rows = Math.ceil(h / CELL_H); const cells = []; for (let r = 0; r < rows; r++) { for (let c = 0; c < cols; c++) { cells.push({ x: c * CELL_W + CELL_W / 2, y: r * CELL_H + CELL_H / 2 + 6, val: Math.random() > 0.5 ? "1" : "0", nextFlip: performance.now() + Math.random() * 8000 + 3000, phase: Math.random() * Math.PI * 2, }); } } cellsRef.current = cells; }; const handleResize = () => { ctx.setTransform(1, 0, 0, 1, 0, 0); setupCells(); }; const handleMove = (e) => { const rect = canvas.getBoundingClientRect(); mouseRef.current.x = e.clientX - rect.left; mouseRef.current.y = e.clientY - rect.top; mouseRef.current.inside = true; }; const handleLeave = () => { mouseRef.current.inside = false; mouseRef.current.x = -9999; mouseRef.current.y = -9999; }; setupCells(); window.addEventListener("resize", handleResize); window.addEventListener("mousemove", handleMove); canvas.addEventListener("mouseleave", handleLeave); let waveColorIdx = 0; const spawnWave = () => { const w = canvas.clientWidth; const h = canvas.clientHeight; const angle = (Math.random() - 0.5) * (Math.PI / 3) + Math.PI / 2; const nx = Math.cos(angle); const ny = Math.sin(angle); const dots = [0, w * nx, h * ny, w * nx + h * ny]; const minDot = Math.min(...dots); const maxDot = Math.max(...dots); const direction = Math.random() > 0.5 ? 1 : -1; const startPos = direction > 0 ? minDot - 200 : maxDot + 200; const endPos = direction > 0 ? maxDot + 200 : minDot - 200; const speed = 220 + Math.random() * 140; const color = PRISM[waveColorIdx % PRISM.length]; waveColorIdx++; wavesRef.current.push({ nx, ny, pos: startPos, endPos, direction, speed, color, bandWidth: 110, startedAt: performance.now(), flippedSet: new Set(), }); }; let nextWaveAt = performance.now() + 2200; let lastT = performance.now(); const render = (now) => { const dt = Math.min(now - lastT, 64); lastT = now; const w = canvas.clientWidth; const h = canvas.clientHeight; ctx.clearRect(0, 0, w, h); ctx.font = "11px 'JetBrains Mono', monospace"; ctx.textAlign = "center"; ctx.textBaseline = "alphabetic"; if (now > nextWaveAt) { spawnWave(); nextWaveAt = now + 5500 + Math.random() * 3500; } const activeWaves = []; for (const wv of wavesRef.current) { wv.pos += wv.direction * wv.speed * (dt / 1000); const stillOnscreen = (wv.direction > 0 && wv.pos < wv.endPos) || (wv.direction < 0 && wv.pos > wv.endPos); if (stillOnscreen) activeWaves.push(wv); } wavesRef.current = activeWaves; const mx = mouseRef.current.x; const my = mouseRef.current.y; const LENS_R = 160; const LENS_R2 = LENS_R * LENS_R; const cells = cellsRef.current; for (let i = 0; i < cells.length; i++) { const cell = cells[i]; if (now > cell.nextFlip) { cell.val = cell.val === "1" ? "0" : "1"; cell.nextFlip = now + 4000 + Math.random() * 9000; } const pulse = 0.5 + 0.5 * Math.sin(now * 0.0006 + cell.phase); let opacity = 0.02 + pulse * 0.015; let r = 235, g = 229, b = 216; let waveBoost = 0; let tintR = 0, tintG = 0, tintB = 0, tintW = 0; for (const wv of activeWaves) { const cellPos = cell.x * wv.nx + cell.y * wv.ny; const d = cellPos - wv.pos; const ad = Math.abs(d); if (ad < wv.bandWidth) { const sigma = 30; const glow = Math.exp(-(d * d) / (2 * sigma * sigma)); const crest = Math.max(0, 1 - ad / wv.bandWidth); const intensity = glow * 0.85 + crest * 0.15; waveBoost += intensity * 0.85; tintR += wv.color[0] * intensity; tintG += wv.color[1] * intensity; tintB += wv.color[2] * intensity; tintW += intensity; if (glow > 0.85 && !wv.flippedSet.has(i) && Math.random() < 0.18) { cell.val = cell.val === "1" ? "0" : "1"; wv.flippedSet.add(i); } } } let lensFalloff = 0; if (mouseRef.current.inside) { const dx = cell.x - mx; const dy = cell.y - my; const d2 = dx * dx + dy * dy; if (d2 < LENS_R2) { lensFalloff = 1 - Math.sqrt(d2) / LENS_R; } } if (lensFalloff > 0) { opacity = Math.max(opacity, 0.08 + lensFalloff * 0.4); } if (waveBoost > 0) { opacity = Math.min(0.95, opacity + waveBoost); if (tintW > 0) { const tintMix = Math.min(1, waveBoost * 1.2); const tr = tintR / tintW; const tg = tintG / tintW; const tb = tintB / tintW; r = r * (1 - tintMix) + tr * tintMix; g = g * (1 - tintMix) + tg * tintMix; b = b * (1 - tintMix) + tb * tintMix; } } ctx.fillStyle = `rgba(${r | 0}, ${g | 0}, ${b | 0}, ${opacity})`; ctx.fillText(cell.val, cell.x, cell.y); } rafRef.current = requestAnimationFrame(render); }; rafRef.current = requestAnimationFrame(render); return () => { cancelAnimationFrame(rafRef.current); window.removeEventListener("resize", handleResize); window.removeEventListener("mousemove", handleMove); canvas.removeEventListener("mouseleave", handleLeave); }; }, []); useEffect(() => { messagesHistoryRef.current = messages; }, [messages]); const sendMessage = (text) => { const trimmed = text.trim(); if (!trimmed || isStreaming || isThinking) return; setInput(""); const nextHistory = [ ...messagesHistoryRef.current, { role: "user", content: trimmed }, ]; setMessages(nextHistory); setIsThinking(true); workerRef.current?.postMessage({ type: "generate", data: nextHistory }); }; const interruptGeneration = () => { workerRef.current?.postMessage({ type: "interrupt" }); }; const resetChat = () => { setInput(""); setMessages([]); messagesHistoryRef.current = []; setIsThinking(false); setIsStreaming(false); setTps(null); autoScrollRef.current = true; workerRef.current?.postMessage({ type: "reset" }); setTimeout(() => inputRef.current?.focus(), 0); }; const handleMessagesScroll = useCallback(() => { const container = messagesRef.current; if (!container) return; const distanceFromBottom = container.scrollHeight - container.scrollTop - container.clientHeight; autoScrollRef.current = distanceFromBottom < 96; }, []); useEffect(() => { if (!autoScrollRef.current) return; const container = messagesRef.current; if (!container) return; const id = requestAnimationFrame(() => { container.scrollTo({ top: container.scrollHeight, behavior: isStreaming ? "auto" : "smooth", }); }); return () => cancelAnimationFrame(id); }, [messages, isThinking, isStreaming]); const SUGGESTIONS = [ "Who are you?", "Write a short poem about AI.", "What is the capital of France?", "Solve x^2 - 5x + 6 = 0.", ]; const handleKeyDown = (e) => { if (e.key === "Enter" && !e.shiftKey) { e.preventDefault(); sendMessage(input); } }; let tpsDisplay = "—"; if (isThinking) tpsDisplay = "···"; else if (tps !== null) tpsDisplay = tps.toFixed(1); return ( <>
Run 1-bit Bonsai LLMs (1.7B, 4B, 8B) entirely locally in your browser on WebGPU, powered by Transformers.js.
Each Bonsai model runs entirely in your browser via WebGPU. Pick a size — smaller loads faster, larger reasons better.
{errorMessage}