Gemma-4-WebGPU / src /hooks /useModel.js
shreyask's picture
Upload folder using huggingface_hub
0c2fc21 verified
raw
history blame
2.48 kB
// web/src/hooks/useModel.js
import { useState, useEffect, useRef, useCallback } from "react";
import { read_audio } from "@huggingface/transformers";
/**
 * React hook that owns the model Web Worker (`../worker.js`, module type) and
 * exposes its lifecycle as React state plus imperative commands.
 *
 * The worker is created once on mount and terminated on unmount. Messages it
 * posts are routed by their `type` field:
 *   - "status"   -> sets `status`; also clears `loadProgress` when the status is "ready"
 *   - "progress" -> stores the rest of the message as `loadProgress`
 *   - "error"    -> sets `error` and status "error"; the pending generation's
 *                   `onComplete("", message)` is called once and then dropped
 *   - "update"   -> forwards `text` to the pending generation's `onUpdate`
 *   - "complete" -> sets status "ready"; forwards `text` to the pending
 *                   generation's `onComplete`, which is then dropped
 * A worker `error` event (script failed to load / uncaught exception) is
 * treated like an "error" message.
 *
 * @returns {{
 *   status: string,
 *   loadProgress: ?object,
 *   error: *,
 *   checkWebGPU: () => void,
 *   loadModel: () => void,
 *   generate: (opts: {
 *     messages: *,
 *     imageUrl?: string,
 *     audioUrl?: string,
 *     enableThinking?: boolean,
 *     onUpdate?: (text: *) => void,
 *     onComplete?: (text: *, errorMessage?: *) => void,
 *   }) => Promise<void>,
 *   interrupt: () => void,
 * }} `error` holds the message reported by the worker (presumably a string —
 *   confirm against worker.js) or null.
 */
export function useModel() {
  // idle | webgpu-available | webgpu-unavailable | loading | ready | generating | error
  const [status, setStatus] = useState("idle");
  const [loadProgress, setLoadProgress] = useState(null);
  const [error, setError] = useState(null);
  const workerRef = useRef(null);
  // Callbacks of the in-flight generate() call; null when nothing is pending.
  const callbacksRef = useRef(null);

  useEffect(() => {
    const worker = new Worker(new URL("../worker.js", import.meta.url), {
      type: "module",
    });

    // Detach the pending callbacks *before* invoking them so that a callback
    // which synchronously starts a new generate() keeps its own callbacks, and
    // so a pending generation is settled at most once.
    const takePending = () => {
      const pending = callbacksRef.current;
      callbacksRef.current = null;
      return pending;
    };

    // Record a failure and settle the pending generation (if any) with it.
    const fail = (message) => {
      setError(message);
      setStatus("error");
      takePending()?.onComplete?.("", message);
    };

    worker.onmessage = (e) => {
      const { type, ...data } = e.data;
      switch (type) {
        case "status":
          setStatus(data.status);
          if (data.status === "ready") setLoadProgress(null);
          break;
        case "progress":
          setLoadProgress(data);
          break;
        case "error":
          fail(data.message);
          break;
        case "update":
          callbacksRef.current?.onUpdate?.(data.text);
          break;
        case "complete":
          setStatus("ready");
          takePending()?.onComplete?.(data.text);
          break;
      }
    };

    // Fires when the worker script fails to load or throws an uncaught error;
    // without it `status` would stay at its last value with no error shown.
    // A load failure may deliver a plain Event with no `message`.
    worker.onerror = (event) => {
      fail(event?.message || "Model worker failed to load or crashed.");
    };

    workerRef.current = worker;
    return () => {
      worker.terminate();
      // Drop the reference so later commands become no-ops instead of posting
      // to a terminated worker; guarded in case a newer worker replaced it.
      if (workerRef.current === worker) workerRef.current = null;
    };
  }, []);

  const checkWebGPU = useCallback(() => {
    workerRef.current?.postMessage({ type: "check" });
  }, []);

  const loadModel = useCallback(() => {
    // A new attempt invalidates any previously reported error.
    setError(null);
    workerRef.current?.postMessage({ type: "load" });
  }, []);

  const generate = useCallback(
    async ({ messages, imageUrl, audioUrl, enableThinking, onUpdate, onComplete }) => {
      callbacksRef.current = { onUpdate, onComplete };
      setError(null);

      let audioData = null;
      if (audioUrl) {
        try {
          // Decoded via transformers.js read_audio at a 16000 Hz sampling rate.
          audioData = await read_audio(audioUrl, 16000);
        } catch (err) {
          // Best-effort: continue without audio rather than aborting the request.
          console.error("Audio decode failed:", err);
        }
      }

      const msg = {
        type: "generate",
        messages,
        imageUrl: imageUrl || null,
        audioData,
        enableThinking: enableThinking || false,
      };
      // Transfer (not copy) the decoded samples' buffer to the worker.
      // workerRef is read after the await, so an unmount during decoding skips the post.
      workerRef.current?.postMessage(msg, audioData ? [audioData.buffer] : []);
    },
    [],
  );

  const interrupt = useCallback(() => {
    workerRef.current?.postMessage({ type: "interrupt" });
  }, []);

  return { status, loadProgress, error, checkWebGPU, loadModel, generate, interrupt };
}