Spaces:
Running on Zero
Mihai Maruseac committed on
ZeroGPU Gradio demo for OpenAI Privacy Filter
Updates the README to be informative and adds all the necessary files to showcase the ZeroGPU Gradio demo of the OpenAI Privacy Filter
Signed-off-by: Mihai Maruseac <mihaimaruseac@openai.com>
- README.md +29 -4
- app.py +1319 -0
- requirements.txt +5 -0
README.md
CHANGED
@@ -1,6 +1,6 @@
 ---
-title: Privacy Filter
-emoji:
+title: OpenAI Privacy Filter
+emoji: 🛡️
 colorFrom: gray
 colorTo: gray
 sdk: gradio
@@ -9,7 +9,32 @@ python_version: '3.12'
 app_file: app.py
 pinned: false
 license: apache-2.0
-short_description: Privacy Filter
+short_description: OpenAI Privacy Filter ZeroGPU demo
 ---

-
+# OpenAI Privacy Filter
+
+OpenAI Privacy Filter is a bidirectional token-classification model for detecting and masking personally identifiable information (PII) in text. It is intended for high-throughput data-sanitization workflows where teams need a fast, context-aware, and tunable model that they can run on-premises.
+
+OpenAI Privacy Filter is pretrained autoregressively to arrive at a checkpoint with an architecture similar to gpt-oss, albeit smaller. We then converted that checkpoint into a bidirectional token classifier over a privacy label taxonomy and post-trained it with a supervised classification loss. (For architecture details about gpt-oss, please see the gpt-oss model card.) Instead of generating text token by token, this model labels an input sequence in a single forward pass, then decodes coherent spans with a constrained Viterbi procedure. For each input token, the model predicts a probability distribution over the label taxonomy, which consists of the 8 output categories described below.
+
+Highlights:
+
+- Permissive Apache 2.0 license: ideal for experimentation, customization, and commercial deployment.
+- Small size: runs in a web browser or on a laptop – 1.5B parameters total and 50M active parameters.
+- Fine-tunable: adapt the model to specific data distributions through easy and data-efficient fine-tuning.
+- Long context: a 128,000-token context window enables processing long text with high throughput and no chunking.
+- Runtime control: configure precision/recall trade-offs and detected span lengths through preset operating points.
+
+## Metadata
+
+- Developed by: OpenAI
+- Funded by: OpenAI
+- Shared by: OpenAI
+- Model type: Bidirectional token classification model for privacy span detection
+- Language(s): Primarily English; selected multilingual robustness evaluation reported
+- License: [Apache 2.0](LICENSE)
+
+- Source repository: https://github.com/openai/privacy-filter
+- Model weights: https://huggingface.co/openai/privacy-filter
+- Model card: [OpenAI Privacy Filter Model Card](https://cdn.openai.com/pdf/c66281ed-b638-456a-8ce1-97e9f5264a90/OpenAI-Privacy-Filter-Model-Card.pdf)
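The masking workflow the README describes reduces to replacing detected character spans with category placeholders. Below is a minimal sketch of that final step, assuming spans arrive as (label, char_start, char_end) triples like those produced in app.py; the mask_spans helper name and the [LABEL] placeholder format are illustrative assumptions, not part of the released API.

```python
# Minimal sketch: mask detected PII spans in text.
# Assumes (label, char_start, char_end) triples; helper name and
# placeholder format are illustrative, not the released API.
def mask_spans(text: str, spans: list[tuple[str, int, int]]) -> str:
    out: list[str] = []
    cursor = 0
    for label, start, end in sorted(spans, key=lambda span: span[1]):
        if start < cursor:  # skip overlapping spans
            continue
        out.append(text[cursor:start])    # copy untouched text
        out.append(f"[{label.upper()}]")  # replace the span with its category
        cursor = end
    out.append(text[cursor:])
    return "".join(out)


print(mask_spans("Call Alice at 555-0100.",
                 [("private_person", 5, 10), ("private_phone", 14, 22)]))
# Call [PRIVATE_PERSON] at [PRIVATE_PHONE].
```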
app.py
ADDED
@@ -0,0 +1,1319 @@
import dataclasses
import functools
import json
import math
import os

from bisect import bisect_left, bisect_right
from collections.abc import Sequence
from dataclasses import dataclass
from pathlib import Path
from typing import Final

import gradio as gr
import spaces
import torch
import torch.nn.functional as F
from safetensors import safe_open

import tiktoken

from huggingface_hub import snapshot_download

MODEL_ROOT = snapshot_download("openai/privacy-filter", allow_patterns=["original/*"])
MODEL_DIR = Path(MODEL_ROOT) / "original"

PRIVACY_FILTER_MODEL_TYPE: Final[str] = "privacy_filter"
REQUIRED_MODEL_CONFIG_KEYS: Final[tuple[str, ...]] = (
    "model_type",
    "encoding",
    "num_hidden_layers",
    "num_experts",
    "experts_per_token",
    "vocab_size",
    "num_labels",
    "hidden_size",
    "intermediate_size",
    "head_dim",
    "num_attention_heads",
    "num_key_value_heads",
    "sliding_window",
    "bidirectional_context",
    "bidirectional_left_context",
    "bidirectional_right_context",
    "default_n_ctx",
    "initial_context_length",
    "rope_theta",
    "rope_scaling_factor",
    "rope_ntk_alpha",
    "rope_ntk_beta",
    "param_dtype",
)
BACKGROUND_CLASS_LABEL: Final[str] = "O"
BOUNDARY_PREFIXES: Final[tuple[str, ...]] = ("B", "I", "E", "S")
EMPTY_HIGHLIGHT_PAYLOAD = {"text": "", "entities": []}
SPAN_CLASS_NAMES: Final[tuple[str, ...]] = (
    BACKGROUND_CLASS_LABEL,
    "account_number",
    "private_address",
    "private_date",
    "private_email",
    "private_person",
    "private_phone",
    "private_url",
    "secret",
)
NER_CLASS_NAMES: Final[tuple[str, ...]] = (BACKGROUND_CLASS_LABEL,) + tuple(
    f"{prefix}-{base_label}"
    for base_label in SPAN_CLASS_NAMES
    if base_label != BACKGROUND_CLASS_LABEL
    for prefix in BOUNDARY_PREFIXES
)
VITERBI_TRANSITION_BIAS_KEYS: Final[tuple[str, ...]] = (
    "transition_bias_background_stay",
    "transition_bias_background_to_start",
    "transition_bias_inside_to_continue",
    "transition_bias_inside_to_end",
    "transition_bias_end_to_background",
    "transition_bias_end_to_start",
)
DEFAULT_VITERBI_CALIBRATION_PRESET: Final[str] = "default"

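# Note: NER_CLASS_NAMES crosses the 8 non-background SPAN_CLASS_NAMES with the
# four BIES boundary prefixes and prepends the background label "O", giving the
# 8 * 4 + 1 = 33 classes that validate_model_config_contract() enforces below.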
def validate_model_config_contract(
    checkpoint_config: dict[str, object],
    *,
    context: str,
) -> None:
    missing = [key for key in REQUIRED_MODEL_CONFIG_KEYS if key not in checkpoint_config]
    if missing:
        raise ValueError(f"{context} is missing required model config keys: {', '.join(missing)}")
    model_type = checkpoint_config.get("model_type")
    if model_type != PRIVACY_FILTER_MODEL_TYPE:
        raise ValueError(
            f"{context} model_type must be {PRIVACY_FILTER_MODEL_TYPE!r}, got {model_type!r}"
        )
    if checkpoint_config.get("bidirectional_context") is not True:
        raise ValueError(f"{context} must use bidirectional_context=true")

    raw_left_context = checkpoint_config.get("bidirectional_left_context")
    raw_right_context = checkpoint_config.get("bidirectional_right_context")
    if (
        not isinstance(raw_left_context, int)
        or isinstance(raw_left_context, bool)
        or not isinstance(raw_right_context, int)
        or isinstance(raw_right_context, bool)
    ):
        raise ValueError(
            f"{context} bidirectional context sizes must be integers "
            f"(got {raw_left_context!r}/{raw_right_context!r})"
        )
    left_context = raw_left_context
    right_context = raw_right_context
    if left_context < 0 or right_context < 0:
        raise ValueError(
            f"{context} bidirectional context sizes must be >= 0 "
            f"(got {left_context}/{right_context})"
        )
    if left_context != right_context:
        raise ValueError(
            f"{context} bidirectional context must be symmetric "
            f"(got left={left_context}, right={right_context})"
        )

    raw_sliding_window = checkpoint_config.get("sliding_window")
    if not isinstance(raw_sliding_window, int) or isinstance(raw_sliding_window, bool):
        raise ValueError(f"{context} sliding_window must be an integer, got {raw_sliding_window!r}")
    sliding_window = raw_sliding_window
    expected_sliding_window = 2 * left_context + 1
    if sliding_window != expected_sliding_window:
        raise ValueError(
            f"{context} sliding_window must equal 2 * bidirectional context + 1 "
            f"(got {sliding_window}, expected {expected_sliding_window})"
        )

    num_labels_raw = checkpoint_config["num_labels"]
    if not isinstance(num_labels_raw, int) or isinstance(num_labels_raw, bool):
        raise ValueError(f"{context} num_labels must be an integer, got {num_labels_raw!r}")
    num_labels = num_labels_raw
    if num_labels != 33:
        raise ValueError(
            f"{context} must use num_labels=33 for the label space, got {num_labels}"
        )

    raw_encoding = checkpoint_config["encoding"]
    if not isinstance(raw_encoding, str) or not raw_encoding.strip():
        raise ValueError(f"{context} encoding must be a non-empty string")

    raw_n_ctx = checkpoint_config["default_n_ctx"]
    if not isinstance(raw_n_ctx, int) or isinstance(raw_n_ctx, bool):
        raise ValueError(f"{context} default_n_ctx must be a positive integer, got {raw_n_ctx!r}")
    n_ctx = raw_n_ctx
    if n_ctx <= 0:
        raise ValueError(f"{context} default_n_ctx must be positive, got {n_ctx}")

    raw_param_dtype = checkpoint_config["param_dtype"]
    if raw_param_dtype != "bfloat16":
        raise ValueError(f"{context} param_dtype must be bfloat16, got {raw_param_dtype!r}")


def expert_linear(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
) -> torch.Tensor:
    num_rows, experts, k_dim = x.shape
    _, _, _, out_dim = weight.shape
    x_bmm = x.reshape(num_rows * experts, 1, k_dim)
    w_bmm = weight.reshape(num_rows * experts, k_dim, out_dim)
    out = torch.bmm(x_bmm, w_bmm).reshape(num_rows, experts, out_dim)
    if bias is not None:
        out = out + bias
    return out


@dataclass
class ModelConfig:
    num_hidden_layers: int
    num_experts: int
    experts_per_token: int
    vocab_size: int
    num_labels: int
    hidden_size: int
    intermediate_size: int
    head_dim: int
    num_attention_heads: int
    num_key_value_heads: int
    bidirectional_context_size: int
    initial_context_length: int
    rope_theta: float
    rope_scaling_factor: float
    rope_ntk_alpha: float
    rope_ntk_beta: float

    @classmethod
    def from_checkpoint_config(
        cls,
        checkpoint_config: dict[str, object],
        *,
        context: str,
    ) -> "ModelConfig":
        checkpoint_config = dict(checkpoint_config)
        checkpoint_config["bidirectional_context_size"] = checkpoint_config[
            "bidirectional_left_context"
        ]
        fields = {field.name: field for field in dataclasses.fields(cls)}
        config_values = {
            key: value for key, value in checkpoint_config.items() if key in fields
        }

        missing = [
            name
            for name, field in fields.items()
            if field.default is dataclasses.MISSING
            and field.default_factory is dataclasses.MISSING
            and name not in config_values
        ]
        if missing:
            raise ValueError(
                f"{context} is missing required model config fields: {', '.join(missing)}"
            )

        try:
            return cls(**config_values)
        except TypeError as exc:
            raise ValueError(f"Invalid model config payload at {context}: {exc}") from exc


class RMSNorm(torch.nn.Module):
    def __init__(
        self, num_features: int, eps: float = 1e-05, device: torch.device | None = None
    ) -> None:
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.scale = torch.nn.Parameter(
            torch.ones(num_features, device=device, dtype=torch.float32)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        t = x.float()
        t = t * torch.rsqrt(torch.mean(t**2, dim=-1, keepdim=True) + self.eps)
        return (t * self.scale).to(x.dtype)


def apply_rope(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> torch.Tensor:
    cos = cos.unsqueeze(-2).to(x.dtype)
    sin = sin.unsqueeze(-2).to(x.dtype)
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    out1 = x1 * cos - x2 * sin
    out2 = x2 * cos + x1 * sin
    return torch.stack((out1, out2), dim=-1).reshape(x.shape)


class RotaryEmbedding(torch.nn.Module):
    def __init__(
        self,
        head_dim: int,
        base: int,
        dtype: torch.dtype,
        *,
        initial_context_length: int = 4096,
        scaling_factor: float = 1.0,
        ntk_alpha: float = 1.0,
        ntk_beta: float = 32.0,
        device: torch.device | None = None,
    ) -> None:
        super().__init__()
        self.head_dim = head_dim
        self.base = base
        self.dtype = dtype
        self.initial_context_length = initial_context_length
        self.scaling_factor = scaling_factor
        self.ntk_alpha = ntk_alpha
        self.ntk_beta = ntk_beta
        self.device = device
        max_positions = int(self.initial_context_length * self.scaling_factor)
        max_positions = max(max_positions, self.initial_context_length)
        self.max_position_embeddings = max_positions
        cos, sin = self._compute_cos_sin(self.max_position_embeddings, device=torch.device("cpu"))
        target_device = device or torch.device("cpu")
        self.register_buffer("cos_cache", cos.to(target_device), persistent=False)
        self.register_buffer("sin_cache", sin.to(target_device), persistent=False)

    def _compute_concentration_and_inv_freq(
        self, device: torch.device | None = None
    ) -> tuple[float, torch.Tensor]:
        device = device or self.device
        freq = self.base ** (
            torch.arange(0, self.head_dim, 2, dtype=torch.float, device=device) / self.head_dim
        )
        if self.scaling_factor > 1.0:
            concentration = 0.1 * math.log(self.scaling_factor) + 1.0
            d_half = self.head_dim / 2
            low = (
                d_half
                * math.log(self.initial_context_length / (self.ntk_beta * 2 * math.pi))
                / math.log(self.base)
            )
            high = (
                d_half
                * math.log(self.initial_context_length / (self.ntk_alpha * 2 * math.pi))
                / math.log(self.base)
            )
            interpolation = 1.0 / (self.scaling_factor * freq)
            extrapolation = 1.0 / freq
            ramp = (torch.arange(d_half, dtype=torch.float32, device=freq.device) - low) / (
                high - low
            )
            mask = 1 - ramp.clamp(0, 1)
            inv_freq = interpolation * (1 - mask) + extrapolation * mask
        else:
            concentration = 1.0
            inv_freq = 1.0 / freq
        return concentration, inv_freq

    def _compute_cos_sin(
        self, num_tokens: int, device: torch.device | None = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        concentration, inv_freq = self._compute_concentration_and_inv_freq(device=device)
        device = device or self.device
        t = torch.arange(num_tokens, dtype=torch.float32, device=device)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        cos = freqs.cos() * concentration
        sin = freqs.sin() * concentration
        return cos.to(self.dtype), sin.to(self.dtype)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        num_tokens = query.shape[0]
        if num_tokens > self.cos_cache.shape[0]:
            cos, sin = self._compute_cos_sin(num_tokens, device=torch.device("cpu"))
            self.cos_cache = cos.to(query.device)
            self.sin_cache = sin.to(query.device)
        if self.cos_cache.device != query.device:
            cos_cache = self.cos_cache.to(query.device)
            sin_cache = self.sin_cache.to(query.device)
        else:
            cos_cache = self.cos_cache
            sin_cache = self.sin_cache
        cos = cos_cache[:num_tokens]
        sin = sin_cache[:num_tokens]

        query_shape = query.shape
        query = query.view(num_tokens, -1, self.head_dim)
        query = apply_rope(query, cos, sin)
        query = query.reshape(query_shape)

        key_shape = key.shape
        key = key.view(num_tokens, -1, self.head_dim)
        key = apply_rope(key, cos, sin)
        key = key.reshape(key_shape)
        return query, key

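# sdpa() below computes banded bidirectional attention without materializing the
# full [T, T] score matrix: K/V are zero-padded and unfolded into per-token windows
# of 2 * context_size + 1 positions, and the learned per-head sink logits are
# appended to the softmax so attention mass can be parked off the real tokens,
# then dropped before the weighted sum over values.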
def sdpa(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    S: torch.Tensor,
    sm_scale: float,
    context_size: int,
) -> torch.Tensor:
    num_tokens, num_heads, q_mult, head_dim = Q.shape
    window = 2 * context_size + 1
    Kp = F.pad(K, (0, 0, 0, 0, context_size, context_size))
    Vp = F.pad(V, (0, 0, 0, 0, context_size, context_size))
    Kwin = Kp.unfold(0, window, 1).permute(0, 3, 1, 2)
    Vwin = Vp.unfold(0, window, 1).permute(0, 3, 1, 2)
    idx = torch.arange(window, device=Q.device) - context_size
    pos = torch.arange(num_tokens, device=Q.device)[:, None] + idx[None, :]
    valid = (pos >= 0) & (pos < num_tokens)
    scores = torch.einsum("nhqd,nwhd->nhqw", Q, Kwin).float()
    scores *= sm_scale
    scores = scores.masked_fill(~valid[:, None, None, :], -float("inf"))
    sink_scores = (S * math.log(2.0)).reshape(num_heads, q_mult)
    sink_scores = sink_scores[None, :, :, None].expand(num_tokens, -1, -1, 1)
    scores = torch.cat([scores, sink_scores], dim=-1)
    weights = torch.softmax(scores, dim=-1)[..., :-1].to(V.dtype)
    attn = torch.einsum("nhqw,nwhd->nhqd", weights, Vwin)
    return attn.reshape(num_tokens, -1)


class AttentionBlock(torch.nn.Module):
    def __init__(
        self,
        config: ModelConfig,
        device: torch.device | None = None,
    ) -> None:
        super().__init__()
        param_dtype = torch.bfloat16
        self.head_dim = config.head_dim
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.bidirectional_context_size = int(config.bidirectional_context_size)
        self.sinks = torch.nn.Parameter(
            torch.empty(config.num_attention_heads, device=device, dtype=torch.float32)
        )
        self.norm = RMSNorm(config.hidden_size, device=device)
        qkv_dim = config.head_dim * (config.num_attention_heads + 2 * config.num_key_value_heads)
        self.qkv = torch.nn.Linear(config.hidden_size, qkv_dim, device=device, dtype=param_dtype)
        self.out = torch.nn.Linear(
            config.head_dim * config.num_attention_heads,
            config.hidden_size,
            device=device,
            dtype=param_dtype,
        )
        self.qk_scale = 1 / math.sqrt(math.sqrt(config.head_dim))
        self.sm_scale = 1.0
        self.rope = RotaryEmbedding(
            config.head_dim,
            int(config.rope_theta),
            torch.float32,
            initial_context_length=config.initial_context_length,
            scaling_factor=config.rope_scaling_factor,
            ntk_alpha=config.rope_ntk_alpha,
            ntk_beta=config.rope_ntk_beta,
            device=device,
        )

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        t = self.norm(x)
        if t.dtype != self.qkv.weight.dtype:
            t = t.to(self.qkv.weight.dtype)
        qkv = F.linear(t, self.qkv.weight, self.qkv.bias)
        query = qkv[:, : self.num_attention_heads * self.head_dim].contiguous()
        key = qkv[
            :,
            self.num_attention_heads * self.head_dim : (
                self.num_attention_heads + self.num_key_value_heads
            )
            * self.head_dim,
        ].contiguous()
        value = qkv[
            :,
            (self.num_attention_heads + self.num_key_value_heads) * self.head_dim : (
                self.num_attention_heads + 2 * self.num_key_value_heads
            )
            * self.head_dim,
        ].contiguous()

        query, key = self.rope(query, key)
        query = query * self.qk_scale
        key = key * self.qk_scale
        sinks = self.sinks
        num_tokens = query.shape[0]
        query = query.view(
            num_tokens,
            self.num_key_value_heads,
            self.num_attention_heads // self.num_key_value_heads,
            self.head_dim,
        )
        key = key.view(num_tokens, self.num_key_value_heads, self.head_dim)
        value = value.view(num_tokens, self.num_key_value_heads, self.head_dim)
        attn_out = sdpa(
            query,
            key,
            value,
            sinks,
            self.sm_scale,
            self.bidirectional_context_size,
        )
        if attn_out.dtype != self.out.weight.dtype:
            attn_out = attn_out.to(self.out.weight.dtype)
        proj_bias = self.out.bias
        proj = F.linear(attn_out, self.out.weight, proj_bias)
        return x + proj.to(x.dtype)

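# swiglu() is the clamped SwiGLU gate used inside the experts; MLPBlock routes each
# token to its top-k experts and mixes the expert outputs with softmax router
# weights, processing tokens in chunks of 32 to bound peak memory.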
def swiglu(
    x: torch.Tensor,
    alpha: float = 1.702,
    limit: float = 7.0,
) -> torch.Tensor:
    x_glu, x_linear = x.chunk(2, dim=-1)
    x_glu = x_glu.clamp(min=None, max=limit)
    x_linear = x_linear.clamp(min=-limit, max=limit)
    out_glu = x_glu * torch.sigmoid(alpha * x_glu)
    return out_glu * (x_linear + 1)


class MLPBlock(torch.nn.Module):
    def __init__(
        self,
        config: ModelConfig,
        device: torch.device | None = None,
    ) -> None:
        super().__init__()
        param_dtype = torch.bfloat16
        self.num_experts = config.num_experts
        self.experts_per_token = config.experts_per_token
        self.swiglu_limit = 7.0
        self.norm = RMSNorm(config.hidden_size, device=device)
        self.gate = torch.nn.Linear(
            config.hidden_size, config.num_experts, device=device, dtype=param_dtype
        )
        self.mlp1_weight = torch.nn.Parameter(
            torch.empty(
                (config.num_experts, config.hidden_size, config.intermediate_size * 2),
                device=device,
                dtype=param_dtype,
            )
        )
        self.mlp1_bias = torch.nn.Parameter(
            torch.empty(
                (config.num_experts, config.intermediate_size * 2),
                device=device,
                dtype=param_dtype,
            )
        )
        self.mlp2_weight = torch.nn.Parameter(
            torch.empty(
                (config.num_experts, config.intermediate_size, config.hidden_size),
                device=device,
                dtype=param_dtype,
            )
        )
        self.mlp2_bias = torch.nn.Parameter(
            torch.empty(
                (config.num_experts, config.hidden_size),
                device=device,
                dtype=param_dtype,
            )
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        t = self.norm(x)
        gate_scores = F.linear(t.float(), self.gate.weight.float(), self.gate.bias.float())
        experts = torch.topk(gate_scores, k=self.experts_per_token, dim=-1, sorted=True)
        expert_weights = torch.softmax(experts.values, dim=-1) / self.experts_per_token

        expert_indices = experts.indices
        experts_per_token_eff = self.experts_per_token

        def _moe_chunk(
            t_chunk: torch.Tensor,
            expert_indices_chunk: torch.Tensor,
            expert_weights_chunk: torch.Tensor,
        ) -> torch.Tensor:
            mlp1_weight = self.mlp1_weight[expert_indices_chunk].float()
            mlp1_bias = self.mlp1_bias[expert_indices_chunk].float()
            t_expanded = t_chunk.float().unsqueeze(1).expand(-1, expert_indices_chunk.shape[1], -1)
            out = expert_linear(
                t_expanded,
                mlp1_weight,
                mlp1_bias,
            )
            out = swiglu(out, limit=self.swiglu_limit)
            mlp2_weight = self.mlp2_weight[expert_indices_chunk].float()
            mlp2_bias = self.mlp2_bias[expert_indices_chunk].float()
            out = expert_linear(
                out.float(),
                mlp2_weight,
                mlp2_bias,
            )
            if out.dtype != expert_weights_chunk.dtype:
                out = out.to(expert_weights_chunk.dtype)
            out = torch.einsum("bec,be->bc", out, expert_weights_chunk)
            out = out * experts_per_token_eff
            return out.to(x.dtype)

        torch_ops_chunk_size = 32
        if t.shape[0] > torch_ops_chunk_size:
            chunks = []
            for start in range(0, t.shape[0], torch_ops_chunk_size):
                end = start + torch_ops_chunk_size
                chunks.append(
                    _moe_chunk(
                        t[start:end],
                        expert_indices[start:end],
                        expert_weights[start:end],
                    )
                )
            t = torch.cat(chunks, dim=0)
        else:
            t = _moe_chunk(t, expert_indices, expert_weights)
        return x + t


class TransformerBlock(torch.nn.Module):
    def __init__(
        self,
        config: ModelConfig,
        device: torch.device | None = None,
    ) -> None:
        super().__init__()
        self.attn = AttentionBlock(config, device=device)
        self.mlp = MLPBlock(config, device=device)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        x = self.attn(x)
        return self.mlp(x)

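# Checkpoint indexes the safetensors shards once, then loads tensors lazily on
# access; build_param_name_map() remaps the runtime MoE parameter names
# (mlp1_*/mlp2_*) onto the names stored on disk (swiglu.*/out.*).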
class Checkpoint:
    @staticmethod
    def build_param_name_map(
        num_hidden_layers: int,
    ) -> dict[str, str]:
        return (
            {
                f"block.{n}.mlp.mlp1_bias": f"block.{n}.mlp.swiglu.bias"
                for n in range(num_hidden_layers)
            }
            | {
                f"block.{n}.mlp.mlp1_weight": f"block.{n}.mlp.swiglu.weight"
                for n in range(num_hidden_layers)
            }
            | {
                f"block.{n}.mlp.mlp2_bias": f"block.{n}.mlp.out.bias"
                for n in range(num_hidden_layers)
            }
            | {
                f"block.{n}.mlp.mlp2_weight": f"block.{n}.mlp.out.weight"
                for n in range(num_hidden_layers)
            }
        )

    def __init__(self, path: str, device: torch.device, num_hidden_layers: int) -> None:
        self.param_name_map = self.build_param_name_map(num_hidden_layers)
        self.device_str = device.type if device.index is None else f"{device.type}:{device.index}"
        safetensor_files = [
            os.path.join(path, filename)
            for filename in os.listdir(path)
            if filename.endswith(".safetensors")
        ]
        tensor_name_to_file: dict[str, str] = {}
        for safetensor_file in safetensor_files:
            with safe_open(safetensor_file, framework="pt", device=self.device_str) as handle:
                for key in handle.keys():
                    prior_file = tensor_name_to_file.get(key)
                    if prior_file is not None:
                        raise ValueError(
                            "Duplicate tensor name in checkpoint shards: "
                            f"{key!r} appears in {prior_file!r} and {safetensor_file!r}"
                        )
                    tensor_name_to_file[key] = safetensor_file
        self.tensor_name_to_file = tensor_name_to_file

    def get(self, name: str) -> torch.Tensor:
        mapped = self.param_name_map.get(name, name)
        return self._get_tensor(mapped)

    def _get_tensor(self, name: str) -> torch.Tensor:
        if name not in self.tensor_name_to_file:
            raise KeyError(f"Tensor {name!r} not found in checkpoint")
        with safe_open(
            self.tensor_name_to_file[name], framework="pt", device=self.device_str
        ) as handle:
            return handle.get_tensor(name)


class Transformer(torch.nn.Module):
    def __init__(self, config: ModelConfig, device: torch.device) -> None:
        super().__init__()
        param_dtype = torch.bfloat16
        self.embedding = torch.nn.Embedding(
            config.vocab_size, config.hidden_size, device=device, dtype=param_dtype
        )
        self.block = torch.nn.ModuleList(
            [
                TransformerBlock(config, device=device)
                for _ in range(config.num_hidden_layers)
            ]
        )
        self.norm = RMSNorm(config.hidden_size, device=device)
        self.unembedding = torch.nn.Linear(
            config.hidden_size,
            config.num_labels,
            bias=False,
            device=device,
            dtype=param_dtype,
        )

    def forward(
        self,
        token_ids: torch.Tensor,
    ) -> torch.Tensor:
        x = self.embedding(token_ids)
        for block in self.block:
            x = block(x)
        x = self.norm(x)
        x = F.linear(x, self.unembedding.weight, None)
        return x

    @classmethod
    def from_checkpoint(
        cls,
        checkpoint_dir: str,
        *,
        device: torch.device,
    ) -> "Transformer":
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False
        torch.set_float32_matmul_precision("highest")
        config_path = Path(checkpoint_dir) / "config.json"
        with config_path.open("r", encoding="utf-8") as handle:
            checkpoint_config = json.load(handle)
        if not isinstance(checkpoint_config, dict):
            raise ValueError(f"Invalid checkpoint config payload at {config_path}")
        validate_model_config_contract(
            checkpoint_config,
            context=str(config_path),
        )

        config = ModelConfig.from_checkpoint_config(
            checkpoint_config,
            context=str(config_path),
        )
        checkpoint = Checkpoint(
            checkpoint_dir,
            device,
            num_hidden_layers=config.num_hidden_layers,
        )

        model = cls(config=config, device=device)
        model.eval()

        for name, param in model.named_parameters():
            loaded_tensor = checkpoint.get(name)
            if param.data.shape != loaded_tensor.shape:
                raise ValueError(
                    f"Tensor shape mismatch for {name!r}: expected {tuple(param.data.shape)}, "
                    f"got {tuple(loaded_tensor.shape)}"
                )
            param.data.copy_(loaded_tensor)

        return model

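# LabelInfo carries the BIES bookkeeping; labels_to_spans() merges per-token label
# ids into (span_label, start, end) token spans, closing any open span whenever the
# boundary-tag sequence, the span label, or the token-index continuity breaks.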
@dataclass(frozen=True)
class LabelInfo:
    boundary_label_lookup: dict[str, dict[str, int]]
    token_to_span_label: dict[int, int]
    token_boundary_tags: dict[int, str | None]
    span_class_names: tuple[str, ...]
    span_label_lookup: dict[str, int]
    background_token_label: int
    background_span_label: int


def labels_to_spans(
    labels_by_index: dict[int, int], label_info: LabelInfo
) -> list[tuple[int, int, int]]:
    spans: list[tuple[int, int, int]] = []
    current_label: int | None = None
    start_idx: int | None = None
    previous_idx: int | None = None
    background_span_label = label_info.background_span_label

    for token_idx in sorted(labels_by_index):
        label_id = labels_by_index[token_idx]
        span_label = label_info.token_to_span_label.get(label_id)
        boundary_tag = label_info.token_boundary_tags.get(label_id)

        if previous_idx is not None and token_idx != previous_idx + 1:
            if current_label is not None and start_idx is not None:
                spans.append((current_label, start_idx, previous_idx + 1))
            current_label = None
            start_idx = None

        if span_label is None:
            previous_idx = token_idx
            continue

        if span_label == background_span_label:
            if current_label is not None and start_idx is not None:
                spans.append((current_label, start_idx, token_idx))
            current_label = None
            start_idx = None
            previous_idx = token_idx
            continue

        if boundary_tag == "S":
            if current_label is not None and start_idx is not None and previous_idx is not None:
                spans.append((current_label, start_idx, previous_idx + 1))
            spans.append((span_label, token_idx, token_idx + 1))
            current_label = None
            start_idx = None
        elif boundary_tag == "B":
            if current_label is not None and start_idx is not None and previous_idx is not None:
                spans.append((current_label, start_idx, previous_idx + 1))
            current_label = span_label
            start_idx = token_idx
        elif boundary_tag == "I":
            if current_label is None or current_label != span_label:
                if current_label is not None and start_idx is not None and previous_idx is not None:
                    spans.append((current_label, start_idx, previous_idx + 1))
                current_label = span_label
                start_idx = token_idx
        elif boundary_tag == "E":
            if current_label is None or current_label != span_label or start_idx is None:
                if current_label is not None and start_idx is not None and previous_idx is not None:
                    spans.append((current_label, start_idx, previous_idx + 1))
                spans.append((span_label, token_idx, token_idx + 1))
                current_label = None
                start_idx = None
            else:
                spans.append((current_label, start_idx, token_idx + 1))
                current_label = None
                start_idx = None
        else:
            if current_label is not None and start_idx is not None and previous_idx is not None:
                spans.append((current_label, start_idx, previous_idx + 1))
            current_label = None
            start_idx = None

        previous_idx = token_idx

    if current_label is not None and start_idx is not None and previous_idx is not None:
        spans.append((current_label, start_idx, previous_idx + 1))
    return spans


def token_spans_to_char_spans(
    spans: Sequence[tuple[int, int, int]],
    char_starts: Sequence[int],
    char_ends: Sequence[int],
) -> list[tuple[int, int, int]]:
    converted: list[tuple[int, int, int]] = []
    for label_idx, token_start, token_end in spans:
        if not (0 <= token_start < token_end <= len(char_starts)):
            continue
        char_start = char_starts[token_start]
        char_end = char_ends[token_end - 1]
        if char_end <= char_start:
            continue
        converted.append((label_idx, char_start, char_end))
    return converted


def trim_char_spans_whitespace(
    spans: Sequence[tuple[int, int, int]],
    text: str,
) -> list[tuple[int, int, int]]:
    trimmed: list[tuple[int, int, int]] = []
    for label_idx, start, end in spans:
        if not (0 <= start < end <= len(text)):
            continue
        while start < end and text[start].isspace():
            start += 1
        while end > start and text[end - 1].isspace():
            end -= 1
        if end > start:
            trimmed.append((label_idx, start, end))
    return trimmed


@dataclass(frozen=True)
class InferenceRuntime:
    model: Transformer
    encoding: tiktoken.Encoding
    label_info: LabelInfo
    device: torch.device
    n_ctx: int

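# Transition biases come from an optional viterbi_calibration.json next to the
# weights (all six biases default to 0.0 when the file is absent); its
# operating_points presets correspond to the preset operating points the README
# mentions for precision/recall control.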
@functools.lru_cache(maxsize=1)
def get_viterbi_transition_biases() -> dict[str, float]:
    calibration_path = MODEL_DIR / "viterbi_calibration.json"
    default_biases = {key: 0.0 for key in VITERBI_TRANSITION_BIAS_KEYS}
    if not calibration_path.is_file():
        return default_biases

    payload = json.loads(calibration_path.read_text(encoding="utf-8"))
    if not isinstance(payload, dict):
        raise ValueError(f"Invalid Viterbi calibration payload at {calibration_path}")

    raw_biases: object = payload
    operating_points = payload.get("operating_points")
    if operating_points is not None:
        if not isinstance(operating_points, dict):
            raise ValueError(f"Invalid operating_points payload at {calibration_path}")
        preset_entry = operating_points.get(DEFAULT_VITERBI_CALIBRATION_PRESET)
        if not isinstance(preset_entry, dict):
            raise ValueError(
                f"Missing operating_points.{DEFAULT_VITERBI_CALIBRATION_PRESET!s} "
                f"in {calibration_path}"
            )
        raw_biases = preset_entry.get("biases")

    if not isinstance(raw_biases, dict):
        raise ValueError(f"Invalid Viterbi bias payload at {calibration_path}")

    resolved_biases: dict[str, float] = {}
    for key in VITERBI_TRANSITION_BIAS_KEYS:
        raw_value = raw_biases.get(key)
        if isinstance(raw_value, bool) or not isinstance(raw_value, (int, float)):
            raise ValueError(f"Missing or invalid {key!r} in {calibration_path}")
        resolved_biases[key] = float(raw_value)
    return resolved_biases


@functools.lru_cache(maxsize=1)
def get_runtime() -> InferenceRuntime:
    checkpoint = MODEL_DIR
    if not checkpoint.exists() or not checkpoint.is_dir():
        raise FileNotFoundError(f"Checkpoint directory not found: {checkpoint}")
    if not any(checkpoint.glob("*.safetensors")):
        raise FileNotFoundError(f"Checkpoint directory has no .safetensors files: {checkpoint}")
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available")
    config_path = checkpoint / "config.json"
    checkpoint_config = json.loads(config_path.read_text(encoding="utf-8"))
    if not isinstance(checkpoint_config, dict):
        raise ValueError(f"Invalid checkpoint config payload at {config_path}")
    validate_model_config_contract(
        checkpoint_config,
        context=str(config_path),
    )
    ner_class_names = NER_CLASS_NAMES
    device = torch.device("cuda")
    n_ctx = int(checkpoint_config["default_n_ctx"])

    encoding = tiktoken.get_encoding(str(checkpoint_config["encoding"]).strip())
    span_class_names: list[str] = [BACKGROUND_CLASS_LABEL]
    span_label_lookup: dict[str, int] = {BACKGROUND_CLASS_LABEL: 0}
    boundary_label_lookup: dict[str, dict[str, int]] = {}
    token_to_span_label: dict[int, int] = {}
    token_boundary_tags: dict[int, str | None] = {}
    background_idx: int | None = None
    for idx, name in enumerate(ner_class_names):
        if name == BACKGROUND_CLASS_LABEL:
            background_idx = idx
            token_to_span_label[idx] = span_label_lookup[BACKGROUND_CLASS_LABEL]
            token_boundary_tags[idx] = None
            continue
        boundary, base_label = name.split("-", 1)
        span_idx = span_label_lookup.get(base_label)
        if span_idx is None:
            span_idx = len(span_class_names)
            span_class_names.append(base_label)
            span_label_lookup[base_label] = span_idx
        token_to_span_label[idx] = span_idx
        token_boundary_tags[idx] = boundary
        boundary_label_lookup.setdefault(base_label, {})[boundary] = idx
    if background_idx is None:
        raise ValueError("Class names must include background label 'O'")
    for base_label, mapping in boundary_label_lookup.items():
        missing = set(BOUNDARY_PREFIXES) - set(mapping)
        if missing:
            raise ValueError(
                f"Missing boundary classes {sorted(missing)} for base label {base_label}"
            )
    label_info = LabelInfo(
        boundary_label_lookup={key: dict(value) for key, value in boundary_label_lookup.items()},
        token_to_span_label=dict(token_to_span_label),
        token_boundary_tags=dict(token_boundary_tags),
        span_class_names=tuple(span_class_names),
        span_label_lookup=dict(span_label_lookup),
        background_token_label=background_idx,
        background_span_label=span_label_lookup[BACKGROUND_CLASS_LABEL],
    )
    model = Transformer.from_checkpoint(
        checkpoint,
        device=device,
    )
    return InferenceRuntime(
        model=model,
        encoding=encoding,
        label_info=label_info,
        device=device,
        n_ctx=n_ctx,
    )

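# Decoder precomputes the constrained-Viterbi score tables: start/end vectors admit
# only legal span openings and closings, and the transition matrix permits B/I -> I/E
# within a single span label and background/E/S -> background or a new B/S start,
# with the calibrated biases added to each allowed transition.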
| 979 |
+
class Decoder:
|
| 980 |
+
def __init__(self, label_info: LabelInfo) -> None:
|
| 981 |
+
self.label_info = label_info
|
| 982 |
+
num_classes = len(label_info.token_to_span_label)
|
| 983 |
+
self._start_scores = torch.full((num_classes,), -1e9, dtype=torch.float32)
|
| 984 |
+
self._end_scores = torch.full((num_classes,), -1e9, dtype=torch.float32)
|
| 985 |
+
self._transition_scores = torch.full((num_classes, num_classes), -1e9, dtype=torch.float32)
|
| 986 |
+
transition_biases = get_viterbi_transition_biases()
|
| 987 |
+
|
| 988 |
+
background_token_idx = label_info.background_token_label
|
| 989 |
+
background_span_idx = label_info.background_span_label
|
| 990 |
+
token_boundary_tags = label_info.token_boundary_tags
|
| 991 |
+
token_to_span_label = label_info.token_to_span_label
|
| 992 |
+
|
| 993 |
+
for idx in range(num_classes):
|
| 994 |
+
tag = token_boundary_tags.get(idx)
|
| 995 |
+
span_label = token_to_span_label.get(idx)
|
| 996 |
+
if tag in {"B", "S"} or idx == background_token_idx:
|
| 997 |
+
self._start_scores[idx] = 0.0
|
| 998 |
+
if tag in {"E", "S"} or idx == background_token_idx:
|
| 999 |
+
self._end_scores[idx] = 0.0
|
| 1000 |
+
|
| 1001 |
+
for next_idx in range(num_classes):
|
| 1002 |
+
next_tag = token_boundary_tags.get(next_idx)
|
| 1003 |
+
next_span_label = token_to_span_label.get(next_idx)
|
| 1004 |
+
if self._is_valid_transition(
|
| 1005 |
+
prev_tag=tag,
|
| 1006 |
+
prev_span=span_label,
|
| 1007 |
+
next_tag=next_tag,
|
| 1008 |
+
next_span=next_span_label,
|
| 1009 |
+
background_token_idx=background_token_idx,
|
| 1010 |
+
background_span_idx=background_span_idx,
|
| 1011 |
+
next_idx=next_idx,
|
| 1012 |
+
):
|
| 1013 |
+
self._transition_scores[idx, next_idx] = self._transition_bias(
|
| 1014 |
+
prev_tag=tag,
|
| 1015 |
+
prev_span=span_label,
|
| 1016 |
+
next_tag=next_tag,
|
| 1017 |
+
next_span=next_span_label,
|
| 1018 |
+
background_span_idx=background_span_idx,
|
| 1019 |
+
biases=transition_biases,
|
| 1020 |
+
)
|
| 1021 |
+
|
| 1022 |
+
@staticmethod
|
| 1023 |
+
def _is_valid_transition(
|
| 1024 |
+
*,
|
| 1025 |
+
prev_tag: str | None,
|
| 1026 |
+
prev_span: int | None,
|
| 1027 |
+
next_tag: str | None,
|
| 1028 |
+
next_span: int | None,
|
| 1029 |
+
background_token_idx: int,
|
| 1030 |
+
background_span_idx: int,
|
| 1031 |
+
next_idx: int,
|
| 1032 |
+
) -> bool:
|
| 1033 |
+
next_is_background = next_span == background_span_idx or next_idx == background_token_idx
|
| 1034 |
+
if (next_span is None or next_tag is None) and not next_is_background:
|
| 1035 |
+
return False
|
| 1036 |
+
|
| 1037 |
+
if prev_span is None or prev_tag is None:
|
| 1038 |
+
return next_is_background or next_tag in {"B", "S"}
|
| 1039 |
+
|
| 1040 |
+
prev_is_background = prev_span == background_span_idx
|
| 1041 |
+
if prev_is_background or prev_tag in {"E", "S"}:
|
| 1042 |
+
return next_is_background or next_tag in {"B", "S"}
|
| 1043 |
+
if prev_tag in {"B", "I"}:
|
| 1044 |
+
return prev_span == next_span and next_tag in {"I", "E"}
|
| 1045 |
+
return False
|
| 1046 |
+
|
| 1047 |
+
@staticmethod
|
| 1048 |
+
def _transition_bias(
|
| 1049 |
+
*,
|
| 1050 |
+
prev_tag: str | None,
|
| 1051 |
+
prev_span: int | None,
|
| 1052 |
+
next_tag: str | None,
|
| 1053 |
+
next_span: int | None,
|
| 1054 |
+
background_span_idx: int,
|
| 1055 |
+
biases: dict[str, float],
|
| 1056 |
+
) -> float:
|
| 1057 |
+
next_is_background = next_span == background_span_idx
|
| 1058 |
+
prev_is_background = prev_span == background_span_idx
|
| 1059 |
+
if prev_is_background:
|
| 1060 |
+
return (
|
| 1061 |
+
biases["transition_bias_background_stay"]
|
| 1062 |
+
if next_is_background
|
| 1063 |
+
else biases["transition_bias_background_to_start"]
|
| 1064 |
+
)
|
| 1065 |
+
if prev_tag in {"B", "I"}:
|
| 1066 |
+
return (
|
| 1067 |
+
biases["transition_bias_inside_to_continue"]
|
| 1068 |
+
if next_tag == "I"
|
| 1069 |
+
else biases["transition_bias_inside_to_end"]
|
| 1070 |
+
)
|
| 1071 |
+
return (
|
| 1072 |
+
biases["transition_bias_end_to_background"]
|
| 1073 |
+
if next_is_background
|
| 1074 |
+
else biases["transition_bias_end_to_start"]
|
| 1075 |
+
)
|
| 1076 |
+
|
| 1077 |
+
    def decode(self, token_logprobs: torch.Tensor) -> list[int]:
        if token_logprobs.ndim != 2:
            raise ValueError("token_logprobs must have shape [seq_len, num_classes]")
        seq_len, num_classes = token_logprobs.shape
        if seq_len == 0:
            return []

        start_scores = self._start_scores.to(
            device=token_logprobs.device,
            dtype=token_logprobs.dtype,
        )
        end_scores = self._end_scores.to(
            device=token_logprobs.device,
            dtype=token_logprobs.dtype,
        )
        transition_scores = self._transition_scores.to(
            device=token_logprobs.device,
            dtype=token_logprobs.dtype,
        )
        scores = token_logprobs[0] + start_scores
        backpointers = torch.empty(
            (seq_len - 1, num_classes),
            device=token_logprobs.device,
            dtype=torch.int64,
        )

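        # Forward pass: scores[j] holds the best score of any legal path
        # ending at the current position with label j; backpointers record
        # which previous label achieved it.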
        for idx in range(1, seq_len):
            transitions = scores.unsqueeze(1) + transition_scores
            best_scores, best_paths = transitions.max(dim=0)
            scores = best_scores + token_logprobs[idx]
            backpointers[idx - 1] = best_paths

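        # If no legal path survives (all scores are -inf), fall back to
        # unconstrained per-token argmax labels.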
        if not torch.isfinite(scores).any():
            return token_logprobs.argmax(dim=1).tolist()

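        # Backtrace: choose the best final label, then follow backpointers
        # to recover the full path.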
        scores = scores + end_scores
        last_label = scores.argmax()
        path = torch.empty((seq_len,), device=token_logprobs.device, dtype=torch.int64)
        path[-1] = last_label
        for idx in range(seq_len - 2, -1, -1):
            last_label = backpointers[idx, last_label]
            path[idx] = last_label
        return path.tolist()


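# End-to-end scoring for one text: tokenize, run the classifier over the
# token ids (in windows if needed), Viterbi-decode the labels, and map the
# resulting token spans back to character offsets.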
@torch.inference_mode()
def predict_text(
    runtime: InferenceRuntime,
    text: str,
    decoder: Decoder,
) -> tuple[str, list[dict[str, object]]]:
    token_ids = tuple(int(token) for token in runtime.encoding.encode(text, allowed_special="all"))
    if not token_ids:
        return text, []

    if runtime.n_ctx <= 0:
        raise ValueError("runtime.n_ctx must be positive")

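    # Score tokens in consecutive windows of at most n_ctx tokens and
    # concatenate the per-token log-probabilities before decoding.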
    token_score_vectors: list[torch.Tensor] = []
    for start in range(0, len(token_ids), runtime.n_ctx):
        end = min(start + runtime.n_ctx, len(token_ids))
        window_tokens = torch.tensor(token_ids[start:end], device=runtime.device, dtype=torch.int32)
        logits = runtime.model(window_tokens)
        log_probs = F.log_softmax(logits.float(), dim=-1)
        if log_probs.shape[0] != window_tokens.shape[0]:
            raise ValueError("Logprob output length does not match window length")
        token_score_vectors.extend(log_probs.unbind(0))

    if not token_score_vectors:
        return text, []

    stacked_scores = torch.stack(token_score_vectors, dim=0)
    decoded_labels = decoder.decode(stacked_scores)
    if len(decoded_labels) != len(token_ids):
        decoded_labels = stacked_scores.argmax(dim=1).tolist()

    predicted_labels_by_index = {
        token_idx: int(label) for token_idx, label in enumerate(decoded_labels)
    }
    predicted_token_spans = labels_to_spans(predicted_labels_by_index, runtime.label_info)
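    # Token-to-character alignment: decode each token's raw bytes, rebuild
    # the text, and map byte offsets to character offsets so spans index
    # into the rendered string (errors="replace" can differ from the input).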
    token_bytes = [runtime.encoding.decode_single_token_bytes(token_id) for token_id in token_ids]
    decoded_text = b"".join(token_bytes).decode("utf-8", errors="replace")
    char_byte_starts: list[int] = []
    char_byte_ends: list[int] = []
    byte_cursor = 0
    for ch in decoded_text:
        char_byte_starts.append(byte_cursor)
        byte_cursor += len(ch.encode("utf-8"))
        char_byte_ends.append(byte_cursor)
    char_starts: list[int] = []
    char_ends: list[int] = []
    token_byte_cursor = 0
    for raw_bytes in token_bytes:
        token_byte_start = token_byte_cursor
        token_byte_end = token_byte_start + len(raw_bytes)
        token_byte_cursor = token_byte_end
        start_idx = bisect_right(char_byte_ends, token_byte_start)
        end_idx = bisect_left(char_byte_starts, token_byte_end)
        if end_idx < start_idx:
            end_idx = start_idx
        char_starts.append(start_idx)
        char_ends.append(end_idx)
    if char_ends and char_ends[-1] != len(decoded_text):
        raise ValueError(
            f"Character length mismatch for decoded text (tokens={char_ends[-1]}, text={len(decoded_text)})"
        )
    decoded_mismatch = decoded_text != text
    source_text = decoded_text if decoded_mismatch else text
    predicted_char_spans = token_spans_to_char_spans(
        predicted_token_spans,
        char_starts,
        char_ends,
    )
    predicted_char_spans = trim_char_spans_whitespace(predicted_char_spans, source_text)

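    # Emit one entity dict per surviving span in the format expected by
    # gr.HighlightedText, skipping spans with out-of-range offsets.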
    detected: list[dict[str, object]] = []
    for label_idx, start, end in predicted_char_spans:
        if not (0 <= start < end <= len(source_text)):
            continue
        label = (
            runtime.label_info.span_class_names[label_idx]
            if 0 <= label_idx < len(runtime.label_info.span_class_names)
            else f"label_{label_idx}"
        )
        detected.append(
            {
                "entity": label,
                "start": int(start),
                "end": int(end),
            }
        )

    return source_text, detected


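# Inference entry point for the demo. The spaces.GPU decorator requests a
# ZeroGPU slice for the duration of each call; the returned dict is the
# payload consumed by gr.HighlightedText.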
@spaces.GPU
def predict(text: str) -> dict[str, object]:
    text = text or ""
    if not text.strip():
        return EMPTY_HIGHLIGHT_PAYLOAD
    runtime = get_runtime()
    decoder = Decoder(label_info=runtime.label_info)
    filtered_text, spans = predict_text(runtime, text, decoder)
    return {
        "text": filtered_text,
        "entities": spans,
    }


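# Build the Gradio Blocks UI. The checkpoint config is validated up front so
# a corrupted download fails fast rather than mid-inference.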
def build_demo() -> gr.Blocks:
    config_path = MODEL_DIR / "config.json"
    checkpoint_config = json.loads(config_path.read_text(encoding="utf-8"))
    if not isinstance(checkpoint_config, dict):
        raise ValueError(f"Invalid checkpoint config payload at {config_path}")
    validate_model_config_contract(
        checkpoint_config,
        context=str(config_path),
    )
    span_class_names = SPAN_CLASS_NAMES
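    # Fixed palette of visually distinct colors, cycled over the
    # non-background span classes for highlighting.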
    web_color_palette = (
        "#e6194b",
        "#3cb44b",
        "#4363d8",
        "#f58231",
        "#911eb4",
        "#008080",
        "#9a6324",
        "#f032e6",
        "#b59f00",
        "#800000",
        "#000075",
        "#808080",
    )
    with gr.Blocks(
        title="OpenAI Privacy Filter",
        fill_width=True,
        elem_id="privacy-filter-app",
    ) as demo:
        gr.Markdown("# OpenAI Privacy Filter Demo")
        gr.Markdown("Example of using OpenAI Privacy Filter (OPF) to mask personal identifiers.")

        with gr.Column(variant="panel"):
            gr.Markdown("Input text:")
            input_text = gr.Textbox(
                lines=2,
                placeholder="Paste text here to detect and mask personal identifiers...",
                show_label=False,
                container=False,
            )

        with gr.Column(variant="panel"):
            gr.Markdown("Text after masking personal identifiers:")
            output_text = gr.HighlightedText(
                value=EMPTY_HIGHLIGHT_PAYLOAD,
                color_map={
                    label: web_color_palette[idx % len(web_color_palette)]
                    for idx, label in enumerate(
                        label for label in span_class_names if label != BACKGROUND_CLASS_LABEL
                    )
                },
                combine_adjacent=False,
                show_legend=False,
                show_label=False,
                container=True,
            )

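        # Both the Submit button and pressing Enter in the textbox run
        # predict; the button click is also exposed as the "predict" API
        # endpoint. Clear resets both panels without a progress indicator.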
        with gr.Row():
            submit_button = gr.Button("Submit", variant="primary")
            clear_button = gr.Button("Clear")

        submit_button.click(
            fn=predict,
            inputs=input_text,
            outputs=output_text,
            api_name="predict",
        )
        input_text.submit(
            fn=predict,
            inputs=input_text,
            outputs=output_text,
        )
        clear_button.click(
            lambda: ("", EMPTY_HIGHLIGHT_PAYLOAD),
            outputs=[input_text, output_text],
            show_progress="hidden",
        )

        gr.Examples(
            examples=[
                ["Alice was born on 1990-01-02 and lives at 1 Main St."],
                ["Email me at alice@example.com or call 415-555-0101."],
            ],
            inputs=input_text,
            outputs=output_text,
            fn=predict,
            cache_examples=False,
        )
    return demo


if __name__ == "__main__":
    demo = build_demo()
    demo.launch()
requirements.txt
ADDED
@@ -0,0 +1,5 @@
gradio>=6.9.0,<7
safetensors>=0.7.0,<1
spaces>=0.47.0,<1
tiktoken>=0.12.0,<1
torch