| import React, { useState } from 'react'; |
| import * as tflite from '@tensorflow/tfjs-tflite'; |
|
|
// Maximum number of images kept from a single file selection.
const MAX_IMAGES = 10;

/**
 * Reads an image File and decodes it into an HTMLImageElement.
 *
 * Rejects when the file cannot be read or is not a decodable image, so a bad
 * file fails the benchmark instead of leaving it waiting forever.
 *
 * @param {File} imageFile - File chosen in the <input type="file">.
 * @returns {Promise<HTMLImageElement>} The loaded image.
 */
function readImageFile(imageFile) {
  return new Promise((resolve, reject) => {
    const reader = new FileReader();
    const img = new Image();

    reader.onload = () => {
      img.src = reader.result;
    };
    reader.onerror = () => {
      reject(new Error(`Could not read file "${imageFile.name}"`, { cause: reader.error }));
    };

    img.onload = () => resolve(img);
    img.onerror = () => {
      reject(new Error(`Could not decode image "${imageFile.name}"`));
    };

    reader.readAsDataURL(imageFile);
  });
}

/**
 * Converts RGBA pixel bytes into RGB floats scaled to [0, 1], dropping alpha.
 *
 * The result holds 3 floats per pixel, which is the element count a
 * `[1, height, width, 3]` tensor needs.
 *
 * @param {Uint8ClampedArray} rgba - Pixel data laid out as R,G,B,A per pixel.
 * @returns {Float32Array} Packed R,G,B values divided by 255.
 */
function rgbaToNormalizedRgb(rgba) {
  const pixelCount = rgba.length / 4;
  const rgb = new Float32Array(pixelCount * 3);
  for (let src = 0, dst = 0; src < rgba.length; src += 4) {
    rgb[dst++] = rgba[src] / 255;
    rgb[dst++] = rgba[src + 1] / 255;
    rgb[dst++] = rgba[src + 2] / 255;
  }
  return rgb;
}

/**
 * Disposes a tensor, an array of tensors, or an object whose values are tensors.
 *
 * This covers the shapes a `predict` call may return. Entries without a
 * `dispose` method are skipped.
 *
 * @param {*} tensors - Value(s) to release; `null`/`undefined` is a no-op.
 */
function disposeTensors(tensors) {
  if (tensors == null) return;
  if (typeof tensors.dispose === 'function') {
    tensors.dispose();
    return;
  }
  const items = Array.isArray(tensors) ? tensors : Object.values(tensors);
  for (const tensor of items) {
    tensor?.dispose?.();
  }
}

/**
 * Browser benchmark for a TFLite object-detection model.
 *
 * The user loads a `.tflite` model and selects up to 10 images. The component
 * runs every image through `model.predict` `repetitions` times and then shows
 * the mean time per `predict` call, in milliseconds. Image decoding and
 * resizing happen once, before timing starts, so they are not included.
 *
 * All props are optional; the defaults reproduce the original hard-coded setup.
 *
 * @param {object} [props]
 * @param {string} [props.modelUrl='./model.tflite'] - URL given to `tflite.loadTFLiteModel`.
 * @param {number} [props.repetitions=50] - How many times each image is run per benchmark.
 * @param {number} [props.inputWidth=320] - Width images are resized to before inference.
 * @param {number} [props.inputHeight=320] - Height images are resized to before inference.
 * @returns {React.ReactElement}
 */
function TFLiteObjectDetection({
  modelUrl = './model.tflite',
  repetitions = 50,
  inputWidth = 320,
  inputHeight = 320,
} = {}) {
  const [averageTime, setAverageTime] = useState(null);
  const [loading, setLoading] = useState(false);
  const [images, setImages] = useState([]);
  const [model, setModel] = useState(null);

  // Keep at most MAX_IMAGES of the selected files.
  const handleFileChange = (event) => {
    const files = Array.from(event.target.files ?? []);
    setImages(files.slice(0, MAX_IMAGES));
  };

  const loadModel = async () => {
    try {
      const loadedModel = await tflite.loadTFLiteModel(modelUrl);
      setModel(loadedModel);
      console.log('Model loaded successfully!');
    } catch (error) {
      console.error('Error loading TFLite model:', error);
    }
  };

  /**
   * Decodes an image file and packs it into a float32 tensor.
   *
   * The image is stretched to inputWidth x inputHeight (aspect ratio is not
   * preserved). The tensor has shape [1, inputHeight, inputWidth, 3] with
   * values in [0, 1].
   *
   * NOTE(review): this assumes the model takes float input normalized to
   * [0, 1]. Quantized detectors often take uint8 input instead; confirm this
   * for your model.
   */
  const preprocessImage = async (imageFile) => {
    const img = await readImageFile(imageFile);

    const canvas = document.createElement('canvas');
    canvas.width = inputWidth;
    canvas.height = inputHeight;
    const context = canvas.getContext('2d');
    context.drawImage(img, 0, 0, inputWidth, inputHeight);

    const { data } = context.getImageData(0, 0, inputWidth, inputHeight);

    // NOTE(review): @tensorflow/tfjs-tflite does not appear to export a `Tensor`
    // constructor. Input tensors are normally built with `tf.tensor4d(...)` from
    // @tensorflow/tfjs-core. Confirm this call works with the installed version.
    return new tflite.Tensor(rgbaToNormalizedRgb(data), [1, inputHeight, inputWidth, 3]);
  };

  const runBenchmark = async () => {
    if (!model || images.length === 0) {
      alert('Please load the model and upload 10 images.');
      return;
    }

    setLoading(true);
    // Clear the previous result so a failed run is not mistaken for a fresh one.
    setAverageTime(null);
    const inputTensors = [];

    try {
      // Decode and resize each image once, up front. This keeps the timed
      // section limited to model.predict, matching the "Average Inference
      // Time" label, and avoids re-decoding the same files on every repetition.
      for (const imageFile of images) {
        inputTensors.push(await preprocessImage(imageFile));
      }

      let totalInferenceTime = 0;
      for (let rep = 0; rep < repetitions; rep++) {
        console.log(`Repetition ${rep + 1} of ${repetitions}`);

        for (const inputTensor of inputTensors) {
          const startTime = performance.now();
          const output = model.predict(inputTensor);
          totalInferenceTime += performance.now() - startTime;

          // Log each image's output once rather than on all repetitions.
          if (rep === 0) {
            console.log('Inference output:', output);
          }
          // Release outputs immediately; otherwise repetitions x images
          // undisposed tensors build up in backend memory.
          disposeTensors(output);
        }
      }

      const inferenceCount = repetitions * inputTensors.length;
      setAverageTime(inferenceCount > 0 ? totalInferenceTime / inferenceCount : null);
    } catch (error) {
      console.error('Error during inference:', error);
    } finally {
      setLoading(false);
      disposeTensors(inputTensors);
    }
  };

  return React.createElement(
    'div',
    null,
    React.createElement('h1', null, 'Object Detection Benchmark (TFLite)'),
    React.createElement('button', { onClick: loadModel, disabled: model !== null }, 'Load Model'),
    React.createElement('input', {
      type: 'file',
      multiple: true,
      accept: 'image/*',
      onChange: handleFileChange,
    }),
    React.createElement(
      'button',
      { onClick: runBenchmark, disabled: loading || !model || images.length === 0 },
      loading ? 'Running Benchmark...' : 'Start Benchmark'
    ),
    React.createElement(
      'div',
      null,
      averageTime !== null
        ? React.createElement(
            'h2',
            null,
            `Average Inference Time: ${averageTime.toFixed(2)} ms`
          )
        : null
    ),
    React.createElement(
      'ul',
      null,
      images.map((img, index) =>
        React.createElement('li', { key: index }, img.name)
      )
    )
  );
}

export default TFLiteObjectDetection;
|
|