myemfar committed on
Commit 406f606 · verified · 1 Parent(s): 66670a1

Update model card: dataset citations, remove private repo refs

Files changed (1): README.md +30 -26
README.md CHANGED
@@ -11,18 +11,20 @@ tags:
 language:
 - en
 pipeline_tag: text-classification
+datasets:
+- deepset/prompt-injections
+- hackaprompt/hackaprompt-dataset
+- lakera-ai/gandalf_ignore_instructions
 ---
 
 # distilbert-multitask
 
-Multi-task DistilBERT classifier for conversational AI pipelines. Trained for use with the [rpchat engine](https://github.com/myemfar/rpchat) a slot-based NPC dialogue pipeline for interactive fiction and games.
-
-Performs two classification tasks in a single forward pass:
-
-| Task | Classes | Notes |
+Multi-task DistilBERT classifier for conversational AI pipelines in interactive fiction and games. Performs two classification tasks in a single forward pass:
+
+| Task | Output | Notes |
 |---|---|---|
-| Dialogue act | 21 categories | See label list below |
-| Manipulation detection | Binary | Detects prompt injection / NPC takeover attempts |
+| Dialogue act | 21-class label | Classifies player utterance type |
+| Manipulation detection | Binary probability | Detects prompt injection / NPC takeover attempts |
 
 ## Dialogue Act Labels
 
@@ -30,44 +32,46 @@ Performs two classification tasks in a single forward pass:
 
 ## Model Details
 
-- **Base model:** `distilbert-base-uncased`
+- **Base model:** [distilbert-base-uncased](https://huggingface.co/distilbert/distilbert-base-uncased)
 - **Format:** ONNX (CPU inference via `onnxruntime`)
 - **Inference time:** ~10–15ms per input on CPU
 - **Input:** Single sentence (player utterance in a conversation)
-- **Output:** Dialogue act label + manipulation probability
+- **Output:** Dialogue act label + manipulation detection probability
 
 ## Usage
 
-The model is consumed automatically by the rpchat engine. On first use, it downloads here if not present locally:
-
-```python
-from huggingface_hub import snapshot_download
-snapshot_download(repo_id="myemfar/distilbert-multitask", local_dir="models/distilbert_multitask")
-```
-
-Then load via `onnxruntime`:
-
 ```python
-import onnxruntime
 import json
+import numpy as np
+import onnxruntime
+from huggingface_hub import snapshot_download
 from transformers import AutoTokenizer
 
-session = onnxruntime.InferenceSession("models/distilbert_multitask/model.onnx")
-tokenizer = AutoTokenizer.from_pretrained("models/distilbert_multitask")
+snapshot_download(repo_id="myemfar/distilbert-multitask", local_dir="./distilbert_multitask")
+
+session = onnxruntime.InferenceSession("./distilbert_multitask/model.onnx")
+tokenizer = AutoTokenizer.from_pretrained("./distilbert_multitask")
 
-with open("models/distilbert_multitask/label_map_da.json") as f:
+with open("./distilbert_multitask/label_map_da.json") as f:
     labels = {int(k): v for k, v in json.load(f).items()}
 
 inputs = tokenizer("Where is the tavern?", return_tensors="np")
 logits_da, logits_manip = session.run(None, dict(inputs))
 
-import numpy as np
-da_label = labels[int(np.argmax(logits_da))]  # "question"
+da_label = labels[int(np.argmax(logits_da))]  # "question"
 manip_prob = float(1 / (1 + np.exp(-logits_manip[0][0])))  # sigmoid
 ```
 
-## Training
-
-Fine-tuned from `distilbert-base-uncased` on a curated dataset of ~2,000 conversational examples across the 21 dialogue act categories, plus a separate manipulation detection dataset targeting prompt injection patterns in NPC conversation contexts.
-
-Training script: `data/train_distilbert.py` in the rpchat repository.
+## Training Data
+
+### Dialogue Act Classification
+Synthetic training data generated via Claude across 21 conversational categories, curated for interactive fiction and RPG dialogue contexts. Approximately 2,000 labeled examples with targeted augmentation at category boundaries.
+
+### Manipulation / Prompt Injection Detection
+Fine-tuned on a combination of three public datasets plus domain-specific negative examples (in-character RPG dialogue):
+
+| Dataset | License | Description |
+|---|---|---|
+| [deepset/prompt-injections](https://huggingface.co/datasets/deepset/prompt-injections) | CC BY 4.0 | Benign queries + prompt injection examples |
+| [hackaprompt/hackaprompt-dataset](https://huggingface.co/datasets/hackaprompt/hackaprompt-dataset) | Apache 2.0 | Red-teaming competition submissions |
+| [lakera-ai/gandalf_ignore_instructions](https://huggingface.co/datasets/lakera-ai/gandalf_ignore_instructions) | CC BY 4.0 | Instruction-override attempts from Lakera's Gandalf challenge |
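The post-processing in the updated usage snippet can be sketched in isolation, without downloading the model. This is a minimal illustration: the logit arrays below are hypothetical stand-ins for `session.run(...)` output, and the label map is truncated (the real `label_map_da.json` has 21 entries).

```python
import json
import numpy as np

# Hypothetical stand-ins for the two heads returned by session.run(...):
# a row of 21 dialogue-act logits and a single manipulation logit per input.
logits_da = np.array([[0.1, 2.5, -0.3] + [0.0] * 18])
logits_manip = np.array([[-2.0]])

# JSON object keys are always strings, hence the int(k) cast in the model card.
raw = '{"0": "greeting", "1": "question", "2": "farewell"}'  # truncated label map
labels = {int(k): v for k, v in json.loads(raw).items()}

da_label = labels[int(np.argmax(logits_da))]
manip_prob = float(1 / (1 + np.exp(-logits_manip[0][0])))  # sigmoid

print(da_label)              # question
print(round(manip_prob, 3))  # 0.119
```

The sigmoid on a single logit mirrors the binary manipulation head; the argmax over 21 logits mirrors the dialogue-act head.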