av-codes committed on
Commit
ed79481
·
verified ·
1 Parent(s): fa98a90

add model toggle: DistilBERT v3 + HRM-Text with ONNX Runtime Web

Browse files
Files changed (1) hide show
  1. index.html +214 -29
index.html CHANGED
@@ -74,6 +74,48 @@
74
  background: var(--accent);
75
  }
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  .input-area {
78
  position: relative;
79
  }
@@ -111,7 +153,7 @@
111
  font-variant-numeric: tabular-nums;
112
  }
113
 
114
- button {
115
  padding: 0.625rem 1.5rem;
116
  background: var(--accent);
117
  color: #fff;
@@ -123,8 +165,8 @@
123
  transition: opacity 0.15s;
124
  }
125
 
126
- button:hover { opacity: 0.9; }
127
- button:disabled { opacity: 0.4; cursor: not-allowed; }
128
 
129
  .result {
130
  margin-top: 1.5rem;
@@ -287,19 +329,30 @@
287
  <div class="loading-overlay" id="loading">
288
  <div class="spinner"></div>
289
  <div class="loading-text" id="loading-text">Loading model...</div>
290
- <div class="loading-detail">~65 MB quantized DistilBERT (one-time download)</div>
291
  </div>
292
 
293
  <div class="container">
294
  <header>
295
  <h1>Prompt Injection Detector</h1>
296
- <p class="subtitle">Detects prompt injection attacks in text using a DistilBERT model fine-tuned on 476K adversarial samples. Runs entirely in your browser.</p>
297
  <div class="model-badge">
298
  <span class="dot"></span>
299
  <span id="status-text">Loading...</span>
300
  </div>
301
  </header>
302
 
 
 
 
 
 
 
 
 
 
 
 
303
  <div class="input-area">
304
  <textarea
305
  id="input"
@@ -308,7 +361,7 @@
308
  ></textarea>
309
  <div class="controls">
310
  <span class="char-count" id="char-count">0 chars</span>
311
- <button id="analyze" disabled>Analyze</button>
312
  </div>
313
  </div>
314
 
@@ -336,20 +389,25 @@
336
  </div>
337
 
338
  <footer>
339
- <p>Model: <a href="https://huggingface.co/av-codes/prompt-injection-detector-v2-bordair">av-codes/prompt-injection-detector-v2-bordair</a></p>
340
- <p>DistilBERT 67M · F1 0.9993 · Trained on <a href="https://huggingface.co/datasets/Bordair/bordair-multimodal">bordair-multimodal</a> (476K samples, v1-v5 adversarial attacks)</p>
341
- <p>Powered by <a href="https://huggingface.co/docs/transformers.js">Transformers.js</a> · Inference runs locally in your browser</p>
342
  </footer>
343
  </div>
344
 
345
  <script type="module">
346
  import { pipeline, env } from 'https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.7.0';
 
347
 
348
  env.allowLocalModels = false;
349
 
350
- const MODEL_ID = 'av-codes/prompt-injection-detector-v2-bordair';
 
 
 
351
  const loadingEl = document.getElementById('loading');
352
  const loadingText = document.getElementById('loading-text');
 
353
  const statusText = document.getElementById('status-text');
354
  const inputEl = document.getElementById('input');
355
  const analyzeBtn = document.getElementById('analyze');
@@ -359,54 +417,177 @@
359
  const resultConfidence = document.getElementById('result-confidence');
360
  const confidenceFill = document.getElementById('confidence-fill');
361
  const latencyEl = document.getElementById('latency');
 
 
 
 
 
 
 
 
 
 
 
 
 
362
 
363
- let classifier = null;
 
 
 
 
 
 
364
 
365
- async function loadModel() {
 
 
366
  try {
367
- loadingText.textContent = 'Downloading model...';
368
- classifier = await pipeline('text-classification', MODEL_ID, {
369
  dtype: 'q8',
370
  device: 'wasm',
371
  progress_callback: (progress) => {
372
  if (progress.status === 'progress' && progress.total) {
373
  const pct = Math.round((progress.loaded / progress.total) * 100);
374
- loadingText.textContent = `Downloading... ${pct}%`;
375
  } else if (progress.status === 'ready') {
376
- loadingText.textContent = 'Model ready';
377
  }
378
  }
379
  });
 
 
 
 
 
 
380
 
381
- loadingEl.classList.add('hidden');
382
- statusText.textContent = 'Ready';
383
- analyzeBtn.disabled = false;
 
 
 
 
 
384
  } catch (err) {
385
  loadingText.textContent = `Error: ${err.message}`;
386
  console.error(err);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
387
  }
 
 
388
  }
389
 
390
  async function analyze() {
391
  const text = inputEl.value.trim();
392
- if (!text || !classifier) return;
 
 
393
 
394
  analyzeBtn.disabled = true;
395
  analyzeBtn.textContent = '...';
396
 
397
  const start = performance.now();
398
- const [result] = await classifier(text, { topk: 1 });
399
- const elapsed = Math.round(performance.now() - start);
400
 
401
- const isInjection = result.label === 'injection' || result.label === 'LABEL_1';
402
- const score = result.score;
403
- const pct = (score * 100).toFixed(1);
 
 
404
 
405
- resultEl.className = `result visible ${isInjection ? 'danger' : 'safe'}`;
406
- resultLabel.textContent = isInjection ? 'Injection Detected' : 'Safe';
 
 
 
407
  resultConfidence.textContent = `${pct}%`;
408
  confidenceFill.style.width = `${pct}%`;
409
- latencyEl.textContent = `${elapsed}ms inference`;
 
410
 
411
  analyzeBtn.disabled = false;
412
  analyzeBtn.textContent = 'Analyze';
@@ -433,7 +614,11 @@
433
  });
434
  });
435
 
436
- loadModel();
 
 
 
 
437
  </script>
438
 
439
  </body>
 
74
  background: var(--accent);
75
  }
76
 
77
+ .model-toggle {
78
+ display: flex;
79
+ gap: 0.5rem;
80
+ margin-bottom: 1rem;
81
+ }
82
+
83
+ .model-toggle-btn {
84
+ flex: 1;
85
+ padding: 0.75rem 0.75rem;
86
+ background: var(--surface);
87
+ border: 1px solid var(--border);
88
+ border-radius: 10px;
89
+ color: var(--text-dim);
90
+ font-size: 0.8125rem;
91
+ cursor: pointer;
92
+ transition: border-color 0.15s, color 0.15s, background 0.15s;
93
+ text-align: left;
94
+ line-height: 1.5;
95
+ }
96
+
97
+ .model-toggle-btn:hover {
98
+ border-color: color-mix(in srgb, var(--accent) 50%, transparent);
99
+ }
100
+
101
+ .model-toggle-btn.active {
102
+ border-color: var(--accent);
103
+ color: var(--text);
104
+ background: color-mix(in srgb, var(--accent) 8%, var(--surface));
105
+ }
106
+
107
+ .model-toggle-btn .toggle-name {
108
+ font-weight: 600;
109
+ font-size: 0.875rem;
110
+ display: block;
111
+ margin-bottom: 0.125rem;
112
+ }
113
+
114
+ .model-toggle-btn .toggle-meta {
115
+ font-size: 0.6875rem;
116
+ opacity: 0.7;
117
+ }
118
+
119
  .input-area {
120
  position: relative;
121
  }
 
153
  font-variant-numeric: tabular-nums;
154
  }
155
 
156
+ button.analyze-btn {
157
  padding: 0.625rem 1.5rem;
158
  background: var(--accent);
159
  color: #fff;
 
165
  transition: opacity 0.15s;
166
  }
167
 
168
+ button.analyze-btn:hover { opacity: 0.9; }
169
+ button.analyze-btn:disabled { opacity: 0.4; cursor: not-allowed; }
170
 
171
  .result {
172
  margin-top: 1.5rem;
 
329
  <div class="loading-overlay" id="loading">
330
  <div class="spinner"></div>
331
  <div class="loading-text" id="loading-text">Loading model...</div>
332
+ <div class="loading-detail" id="loading-detail">~65 MB quantized DistilBERT (one-time download)</div>
333
  </div>
334
 
335
  <div class="container">
336
  <header>
337
  <h1>Prompt Injection Detector</h1>
338
+ <p class="subtitle">Detects prompt injection attacks in text using ML models running entirely in your browser.</p>
339
  <div class="model-badge">
340
  <span class="dot"></span>
341
  <span id="status-text">Loading...</span>
342
  </div>
343
  </header>
344
 
345
+ <div class="model-toggle">
346
+ <button class="model-toggle-btn active" id="toggle-distilbert" data-model="distilbert">
347
+ <span class="toggle-name">DistilBERT v3</span>
348
+ <span class="toggle-meta">67M params &middot; F1 0.9961 &middot; mixed data</span>
349
+ </button>
350
+ <button class="model-toggle-btn" id="toggle-hrm" data-model="hrm">
351
+ <span class="toggle-name">HRM-Text</span>
352
+ <span class="toggle-meta">46.2M params &middot; F1 0.9886 &middot; byte-level</span>
353
+ </button>
354
+ </div>
355
+
356
  <div class="input-area">
357
  <textarea
358
  id="input"
 
361
  ></textarea>
362
  <div class="controls">
363
  <span class="char-count" id="char-count">0 chars</span>
364
+ <button class="analyze-btn" id="analyze" disabled>Analyze</button>
365
  </div>
366
  </div>
367
 
 
389
  </div>
390
 
391
  <footer>
392
+ <p><strong>DistilBERT v3:</strong> <a href="https://huggingface.co/av-codes/prompt-injection-detector-v3-mixed">av-codes/prompt-injection-detector-v3-mixed</a> &middot; 67M params &middot; F1 0.9961 &middot; mixed training data (bordair + v1)</p>
393
+ <p><strong>HRM-Text:</strong> <a href="https://huggingface.co/av-codes/prompt-injection-hrm-text">av-codes/prompt-injection-hrm-text</a> &middot; 46.2M params &middot; F1 0.9886 &middot; from-scratch byte-level &middot; bordair data</p>
394
+ <p>Powered by <a href="https://huggingface.co/docs/transformers.js">Transformers.js</a> + <a href="https://onnxruntime.ai/">ONNX Runtime Web</a> &middot; Inference runs locally in your browser</p>
395
  </footer>
396
  </div>
397
 
398
  <script type="module">
399
  import { pipeline, env } from 'https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.7.0';
400
+ import * as ort from 'https://cdn.jsdelivr.net/npm/onnxruntime-web@1.21.0/dist/ort.min.mjs';
401
 
402
  env.allowLocalModels = false;
403
 
404
+ const DISTILBERT_MODEL_ID = 'av-codes/prompt-injection-detector-v3-mixed';
405
+ const HRM_ONNX_URL = 'https://huggingface.co/av-codes/prompt-injection-hrm-text/resolve/main/onnx/model_fp16.onnx';
406
+ const HRM_MAX_LEN = 2048;
407
+
408
  const loadingEl = document.getElementById('loading');
409
  const loadingText = document.getElementById('loading-text');
410
+ const loadingDetail = document.getElementById('loading-detail');
411
  const statusText = document.getElementById('status-text');
412
  const inputEl = document.getElementById('input');
413
  const analyzeBtn = document.getElementById('analyze');
 
417
  const resultConfidence = document.getElementById('result-confidence');
418
  const confidenceFill = document.getElementById('confidence-fill');
419
  const latencyEl = document.getElementById('latency');
420
+ const toggleDistilbert = document.getElementById('toggle-distilbert');
421
+ const toggleHrm = document.getElementById('toggle-hrm');
422
+
423
+ let activeModel = 'distilbert';
424
+ let distilbertClassifier = null;
425
+ let hrmSession = null;
426
+ let isLoading = false;
427
+
428
/**
 * Show the loading overlay with a headline and detail line for one model.
 *
 * @param {string} modelName - Display name interpolated into the headline.
 * @param {string} sizeInfo - Text written to the detail line of the overlay.
 */
function showLoading(modelName, sizeInfo) {
  const headline = `Loading ${modelName}...`;

  loadingText.textContent = headline;
  loadingDetail.textContent = sizeInfo;

  // Drop the `hidden` class that hideLoading() puts on the overlay.
  loadingEl.classList.remove('hidden');
}
433
 
434
/**
 * Mark the loading overlay with the `hidden` class.
 */
function hideLoading() {
  const { classList } = loadingEl;
  classList.add('hidden');
}
437
+
438
/**
 * Replace the text shown in the status badge.
 *
 * @param {string} text - New badge text; assigned via `textContent`.
 */
function updateBadge(text) {
  const badge = statusText;
  badge.textContent = text;
}
441
 
442
/**
 * Create the DistilBERT text-classification pipeline and cache it in
 * `distilbertClassifier`. Returns immediately when a classifier is already
 * cached.
 *
 * While loading, progress events are mirrored into the overlay headline.
 * On failure the error message is written to the overlay, logged, and the
 * error is rethrown to the caller.
 *
 * @returns {Promise<void>}
 * @throws Whatever `pipeline()` rejects with.
 */
async function loadDistilbert() {
  if (distilbertClassifier) return;

  showLoading('DistilBERT v3', '~65 MB quantized (one-time download)');

  // Translate pipeline progress events into overlay text.
  const reportProgress = (progress) => {
    if (progress.status === 'progress' && progress.total) {
      const pct = Math.round((progress.loaded / progress.total) * 100);
      loadingText.textContent = `Downloading DistilBERT v3... ${pct}%`;
      return;
    }
    if (progress.status === 'ready') {
      loadingText.textContent = 'DistilBERT v3 ready';
    }
  };

  const pipelineOptions = {
    dtype: 'q8',
    device: 'wasm',
    progress_callback: reportProgress,
  };

  try {
    distilbertClassifier = await pipeline('text-classification', DISTILBERT_MODEL_ID, pipelineOptions);
  } catch (err) {
    loadingText.textContent = `Error: ${err.message}`;
    console.error(err);
    throw err;
  }
}
464
 
465
/**
 * Create the ONNX Runtime session for HRM-Text from `HRM_ONNX_URL` and cache
 * it in `hrmSession`. Returns immediately when a session is already cached.
 *
 * On failure the error message is written to the overlay, logged, and the
 * error is rethrown to the caller.
 *
 * @returns {Promise<void>}
 * @throws Whatever `ort.InferenceSession.create()` rejects with.
 */
async function loadHrm() {
  if (hrmSession) return;

  showLoading('HRM-Text', '~94 MB ONNX (one-time download)');

  try {
    loadingText.textContent = 'Downloading HRM-Text...';
    const sessionOptions = { executionProviders: ['wasm'] };
    hrmSession = await ort.InferenceSession.create(HRM_ONNX_URL, sessionOptions);
  } catch (err) {
    loadingText.textContent = `Error: ${err.message}`;
    console.error(err);
    throw err;
  }
}
479
+
480
/**
 * Numerically stable softmax.
 *
 * The maximum logit is subtracted before exponentiation so large values do
 * not overflow to Infinity. The maximum is found with a fold rather than
 * `Math.max(...logits)`, because spreading a long array into call arguments
 * can exceed the engine's argument limit and throw a RangeError.
 *
 * @param {number[]} logits - Raw scores; an empty array yields an empty array.
 * @returns {number[]} Probabilities in the same order as `logits`.
 */
function softmax(logits) {
  // Math.max inside the fold keeps NaN propagation identical to Math.max(...).
  const max = logits.reduce((best, x) => Math.max(best, x), -Infinity);
  const exps = logits.map((x) => Math.exp(x - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map((x) => x / sum);
}
486
+
487
/**
 * Byte-level tokenizer: encodes `text` as UTF-8 and uses each byte value as a
 * token id, padded or truncated to a fixed length.
 *
 * Truncation happens on a byte boundary, so a multi-byte character at the cut
 * point may be split.
 *
 * @param {string} text - Input text.
 * @param {number} [maxLen=HRM_MAX_LEN] - Length of both returned arrays.
 * @returns {{ inputIds: BigInt64Array, attentionMask: BigInt64Array }}
 *   `inputIds` holds the byte values followed by zero padding;
 *   `attentionMask` is 1n for real bytes and 0n for padding.
 */
function tokenizeBytes(text, maxLen = HRM_MAX_LEN) {
  const bytes = new TextEncoder().encode(text);
  const len = Math.min(bytes.length, maxLen);

  // Typed arrays are zero-initialised, so untouched positions are padding.
  const inputIds = new BigInt64Array(maxLen);
  const attentionMask = new BigInt64Array(maxLen);

  for (let i = 0; i < len; i++) {
    inputIds[i] = BigInt(bytes[i]);
  }
  attentionMask.fill(1n, 0, len);

  return { inputIds, attentionMask };
}
503
+
504
/**
 * Classify `text` with the cached DistilBERT pipeline.
 *
 * @param {string} text - Text to classify.
 * @returns {Promise<{ isInjection: boolean, score: number }>} Verdict and the
 *   score of the top-ranked label.
 */
async function analyzeDistilbert(text) {
  // Transformers.js v3 names this option `top_k`; the older `topk` spelling
  // is not read by the v3 pipeline.
  const [result] = await distilbertClassifier(text, { top_k: 1 });

  // Compare case-insensitively so a differently-cased label from the model
  // config (e.g. "INJECTION") is not silently reported as safe.
  const label = String(result.label).toLowerCase();
  const isInjection = label === 'injection' || label === 'label_1';

  return { isInjection, score: result.score };
}
509
+
510
/**
 * Classify `text` with the cached HRM-Text ONNX session.
 *
 * The text is byte-tokenized, fed to the session as `input_ids` and
 * `attention_mask` tensors of shape [1, HRM_MAX_LEN], and the returned
 * `logits` are turned into probabilities with softmax. Index 1 is treated as
 * the injection class and index 0 as the non-injection class.
 *
 * @param {string} text - Text to classify.
 * @returns {Promise<{ isInjection: boolean, score: number }>} Verdict and the
 *   probability of the winning class.
 */
async function analyzeHrm(text) {
  const encoded = tokenizeBytes(text);

  const feeds = {
    input_ids: new ort.Tensor('int64', encoded.inputIds, [1, HRM_MAX_LEN]),
    attention_mask: new ort.Tensor('int64', encoded.attentionMask, [1, HRM_MAX_LEN]),
  };

  const outputs = await hrmSession.run(feeds);

  // NOTE(review): HRM_ONNX_URL points at an fp16 export. If the graph's
  // `logits` output is float16, `.data` may hold raw 16-bit values rather
  // than JS floats — confirm the output dtype.
  const logits = Array.from(outputs.logits.data);
  const [classZeroProb, classOneProb] = softmax(logits);

  const isInjection = logits[1] > logits[0];
  return {
    isInjection,
    score: isInjection ? classOneProb : classZeroProb,
  };
}
529
+
530
/**
 * Make `model` the active model, loading it first if necessary.
 *
 * Behaviour:
 * - Ignored while another load is in flight.
 * - A no-op only when `model` is already active AND already loaded. The
 *   start-up call passes the model that `activeModel` is initialised to, so
 *   checking the active name alone would return before anything is loaded.
 * - If loading fails and the previously active model is loaded, the selection
 *   falls back to it so the page stays usable; otherwise the badge shows
 *   "Error" and the Analyze button stays disabled.
 *
 * @param {string} model - 'distilbert' selects DistilBERT; any other value is
 *   treated as HRM-Text.
 * @returns {Promise<void>} Resolves once the switch has settled; load errors
 *   are handled here and not rethrown.
 */
async function switchModel(model) {
  if (isLoading) return;

  const labelOf = (name) => (name === 'distilbert' ? 'DistilBERT v3' : 'HRM-Text');
  const isLoaded = (name) => Boolean(name === 'distilbert' ? distilbertClassifier : hrmSession);

  if (model === activeModel && isLoaded(model)) return;

  // Point the page state and the toggle highlighting at one model.
  const select = (name) => {
    activeModel = name;
    toggleDistilbert.classList.toggle('active', name === 'distilbert');
    toggleHrm.classList.toggle('active', name === 'hrm');
  };

  const previousModel = activeModel;

  isLoading = true;
  analyzeBtn.disabled = true;
  select(model);
  updateBadge('Loading...');

  try {
    if (!isLoaded(model)) {
      if (model === 'distilbert') {
        await loadDistilbert();
      } else {
        await loadHrm();
      }
    }
    updateBadge(`${labelOf(model)} ready`);
    hideLoading();
    analyzeBtn.disabled = false;
  } catch {
    // The loaders have already logged the error and written it to the overlay.
    if (previousModel !== model && isLoaded(previousModel)) {
      select(previousModel);
      updateBadge(`${labelOf(model)} failed to load · ${labelOf(previousModel)} ready`);
      hideLoading();
      analyzeBtn.disabled = false;
    } else {
      updateBadge('Error');
    }
  } finally {
    // Always release the guard, even if a UI update above throws.
    isLoading = false;
  }
}
563
 
564
  async function analyze() {
565
  const text = inputEl.value.trim();
566
+ if (!text) return;
567
+ if (activeModel === 'distilbert' && !distilbertClassifier) return;
568
+ if (activeModel === 'hrm' && !hrmSession) return;
569
 
570
  analyzeBtn.disabled = true;
571
  analyzeBtn.textContent = '...';
572
 
573
  const start = performance.now();
574
+ let result;
 
575
 
576
+ if (activeModel === 'distilbert') {
577
+ result = await analyzeDistilbert(text);
578
+ } else {
579
+ result = await analyzeHrm(text);
580
+ }
581
 
582
+ const elapsed = Math.round(performance.now() - start);
583
+ const pct = (result.score * 100).toFixed(1);
584
+
585
+ resultEl.className = `result visible ${result.isInjection ? 'danger' : 'safe'}`;
586
+ resultLabel.textContent = result.isInjection ? 'Injection Detected' : 'Safe';
587
  resultConfidence.textContent = `${pct}%`;
588
  confidenceFill.style.width = `${pct}%`;
589
+ const modelLabel = activeModel === 'distilbert' ? 'DistilBERT v3' : 'HRM-Text';
590
+ latencyEl.textContent = `${elapsed}ms inference · ${modelLabel}`;
591
 
592
  analyzeBtn.disabled = false;
593
  analyzeBtn.textContent = 'Analyze';
 
614
  });
615
  });
616
 
617
+ toggleDistilbert.addEventListener('click', () => switchModel('distilbert'));
618
+ toggleHrm.addEventListener('click', () => switchModel('hrm'));
619
+
620
+ // Load default model (DistilBERT v3)
621
+ switchModel('distilbert');
622
  </script>
623
 
624
  </body>