# Adaptive NER: Zero-Shot Named Entity Recognition with Dynamic Labels
An adaptive-label NER model built on openai/privacy-filter (1.4B sparse MoE) and Snowflake/snowflake-arctic-embed-l-v2.0 (568M). Accepts any entity label at inference time with no retraining -- type in "dinosaur species", "gene mutation", or "spell name" and it just works.
Trained on 1M rows spanning 252K unique entity types across PII, cybersecurity, biomedical, legal, and general-domain datasets.
## How It Works
Standard NER models have a fixed label set baked into their classification head. This model replaces that head with dot-product scoring: instead of `hidden -> Linear -> logits`, it computes `hidden . label_embedding -> score`. The label embedding comes from a separate encoder that reads the label's natural-language description.
```
4x Label text --> [Arctic-embed-l (568M)] --> (1024,) --> [Proj 1024->640] --+
   "beginning of email"                                                      |
   "continuation of email"                                                   |-- (4, 640) label matrix
   "end of email"                                                            |
   "complete email"                                                          +
                                                                             |
Input text --> [OpenAI NER backbone (1.4B MoE)] --> (T, 640) -- dot product --> Viterbi --> BIOES tags
```
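In code, the scoring step reduces to a single matmul. A minimal sketch with random tensors standing in for the real model outputs (shapes as in the diagram):

```python
import torch

T, N = 12, 3                          # tokens, labels
hidden = torch.randn(T, 640)          # backbone per-token hidden states
label_emb = torch.randn(N * 4, 640)   # projected B/I/E/S embeddings, 4 per label
o_embed = torch.randn(640)            # learnable "no entity" embedding

# One score per token per tag: O first, then B/I/E/S for each label
embeds = torch.cat([o_embed.unsqueeze(0), label_emb])  # (1 + N*4, 640)
scores = hidden @ embeds.T                             # (T, 1 + N*4)
```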
### Components
| Component | Model | Params | Role |
|---|---|---|---|
| NER backbone | openai/privacy-filter | 1.4B (sparse MoE, ~50M active) | Per-token hidden states (T, 640) |
| Label encoder | snowflake-arctic-embed-l-v2.0 | 568M (XLM-RoBERTa) | Encodes label descriptions -> (1024,) CLS vectors |
| Projection | nn.Linear(1024, 640) | 655K | Bridges Arctic's 1024-dim to backbone's 640-dim |
| O-embed | nn.Parameter(640) | 640 | Learnable embedding for the "no entity" tag |
| Viterbi decoder | Dynamic transition matrix | -- | Enforces valid BIOES sequences at decode time |
### BIOES Tagging
Each label generates 4 embeddings via templated text:
- B -> "beginning of {label}" -- entity start
- I -> "continuation of {label}" -- middle tokens
- E -> "end of {label}" -- entity end
- S -> "complete {label}" -- single-token entities
The Viterbi decoder enforces valid transitions: B->I, B->E, I->I, I->E, S->O, E->O. Cross-label jumps (B-email -> I-phone) are forbidden.
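A minimal sketch of what such a constrained decoder can look like. This is illustrative, not the repo's implementation: it assumes the tag layout index 0 = O followed by B/I/E/S per label (matching the embedding order above), and omits the end-of-sequence constraint for brevity.

```python
import torch

def bioes_transition_mask(n_labels: int) -> torch.Tensor:
    """Allowed BIOES transitions. Tag layout: 0 = O, then B, I, E, S per label."""
    n_tags = 1 + 4 * n_labels
    allowed = torch.zeros(n_tags, n_tags, dtype=torch.bool)
    O = 0
    # Tags that may follow O (or the end of an entity): O itself, any B-*, any S-*
    openers = [O] + [1 + 4 * k for k in range(n_labels)] + [4 + 4 * k for k in range(n_labels)]
    for k in range(n_labels):
        B, I, E, S = 1 + 4 * k, 2 + 4 * k, 3 + 4 * k, 4 + 4 * k
        allowed[B, I] = allowed[B, E] = True   # B -> I, B -> E (same label only)
        allowed[I, I] = allowed[I, E] = True   # I -> I, I -> E
        for nxt in openers:
            allowed[E, nxt] = allowed[S, nxt] = True
    for nxt in openers:
        allowed[O, nxt] = True
    return allowed

def viterbi(scores: torch.Tensor, allowed: torch.Tensor) -> list:
    """Best tag path under the transition mask. scores: (T, n_tags)."""
    T, n_tags = scores.shape
    NEG = -1e9
    trans = torch.zeros(n_tags, n_tags, dtype=scores.dtype)
    trans[~allowed] = NEG
    # Valid start tags are exactly those that may follow O
    dp = scores[0].clone()
    dp[~allowed[0]] = NEG
    back = torch.zeros(T, n_tags, dtype=torch.long)
    for t in range(1, T):
        cand = dp.unsqueeze(1) + trans          # (prev, next)
        dp, back[t] = cand.max(dim=0)
        dp = dp + scores[t]
    path = [int(dp.argmax())]
    for t in range(T - 1, 0, -1):
        path.append(int(back[t, path[-1]]))
    return path[::-1]
```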
### VtV Residual Injection
Beyond simple dot-product scoring, the model uses label-conditioned hidden state modification. The label encoder's intermediate representations (from layers 6, 12, 18, 24) are projected and used to nudge the backbone's hidden states:
- Arctic CLS vectors from intermediate layers are projected (1024 -> 640) and BIOES-averaged per label -> `V` of shape (N, 640)
- Compute `VtV = V^T @ V` -> a (640, 640) matrix per injection point
- Inject between backbone layers 0, 2, 4, 6:
  `hidden = (1 - sigmoid(alpha)) * hidden + sigmoid(alpha) * (hidden @ VtV / N)`
Four learnable alpha scalars control injection strength. This adds only ~2.6M parameters on the Arctic side.
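The injection step itself is only a few lines. A sketch with random stand-in tensors (alpha = 0 here just for illustration; the real alphas are learned, one per injection point):

```python
import torch

N, d, T = 3, 640, 12                 # labels, hidden dim, tokens
V = torch.randn(N, d)                # BIOES-averaged projected label vectors
hidden = torch.randn(T, d)           # backbone hidden states at an injection point
alpha = torch.tensor(0.0)            # learnable scalar gate (stand-in value)

VtV = V.T @ V                        # (640, 640) label-conditioned mixing matrix
g = torch.sigmoid(alpha)
# Gated residual: blend original hidden states with label-conditioned ones
hidden = (1 - g) * hidden + g * (hidden @ VtV / N)
```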
## Repo Structure
```
original/                          # PyTorch (full precision, bf16)
  arctic/                          # Fine-tuned Arctic-embed-l encoder
    config.json
    model.safetensors              # 1.1 GB
    tokenizer.json
    tokenizer_config.json
  backbone/                        # Fine-tuned OpenAI NER backbone
    config.json
    model.safetensors              # 2.6 GB
    tokenizer.json
    tokenizer_config.json
  label_encoder_extra.pt           # 13 MB -- projection, O-embed, prefix projections, alphas
onnx/                              # ONNX (quantized, for browser/edge inference)
  arctic/
    encoder_int8.onnx              # 542 MB -- int8 quantized Arctic encoder
    projection.npz                 # 2.6 MB -- projection weights + O-embed
    tokenizer.json                 # Arctic tokenizer (XLMRobertaTokenizer)
    tokenizer_config.json
  backbone/
    model_quantized_hidden.onnx    # 159 KB -- graph only
    model_quantized.onnx_data      # 1.6 GB -- int8 quantized backbone weights
    tokenizer.json                 # Backbone tokenizer
    tokenizer_config.json
```
## Usage (PyTorch)
```python
import torch
import torch.nn as nn
from transformers import AutoModelForTokenClassification, AutoModel, AutoTokenizer

TEMPLATES = {
    "B": "beginning of {label}",
    "I": "continuation of {label}",
    "E": "end of {label}",
    "S": "complete {label}",
}

# Load models (download the original/ folder from this repo)
backbone_model = AutoModelForTokenClassification.from_pretrained(
    "path/to/original/backbone", trust_remote_code=True, dtype=torch.float32
)
arctic_model = AutoModel.from_pretrained(
    "path/to/original/arctic", dtype=torch.float32
)
extra = torch.load("path/to/original/label_encoder_extra.pt", map_location="cpu", weights_only=True)
bb_tok = AutoTokenizer.from_pretrained("path/to/original/backbone", trust_remote_code=True)
arctic_tok = AutoTokenizer.from_pretrained("path/to/original/arctic")

# Build projection layer
proj = nn.Linear(1024, 640)
proj.weight.data = extra["proj.weight"]
proj.bias.data = extra["proj.bias"]
o_embed = extra["o_embed"]  # (640,)

# Define your labels -- any text works
labels = ["dinosaur species", "geological period", "continent"]
text = "Tyrannosaurus Rex lived during the late Cretaceous period in North America"

# Encode labels: four BIOES template strings per label
label_texts = [
    TEMPLATES[pfx].format(label=label)
    for label in labels
    for pfx in ("B", "I", "E", "S")
]
le_enc = arctic_tok(label_texts, padding=True, truncation=True, max_length=32, return_tensors="pt")
with torch.no_grad():
    out = arctic_model(le_enc["input_ids"], le_enc["attention_mask"])
    cls_vecs = out.last_hidden_state[:, 0]  # (N*4, 1024)
    projected = proj(cls_vecs)              # (N*4, 640)
embeds = torch.cat([o_embed.unsqueeze(0), projected])  # (1 + N*4, 640)

# Encode text
enc = bb_tok(text, add_special_tokens=False, return_tensors="pt")

# Hook into the backbone to capture pre-classifier hidden states.
# (`nonlocal` only works inside nested functions, so capture into a dict.)
captured = {}
for name, m in backbone_model.named_modules():
    if isinstance(m, nn.Linear) and m.out_features == backbone_model.config.num_labels:
        def hook(module, args):
            captured["hidden"] = args[0]
        m.register_forward_pre_hook(hook)
        break

with torch.no_grad():
    backbone_model(input_ids=enc["input_ids"], attention_mask=torch.ones_like(enc["input_ids"]))

# Score and decode
scores = captured["hidden"][0] @ embeds.T  # (T, 1 + N*4)
tags = scores.argmax(dim=-1).tolist()      # or use Viterbi for valid BIOES sequences
```
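The decoded tag indices map back to entities as index 0 = O, then four consecutive tags (B, I, E, S) per label. A small helper (illustrative, not part of the repo) to turn a tag sequence into spans:

```python
def tags_to_spans(tags, labels):
    """Map tag indices (0 = O; then B, I, E, S per label) to (label, start, end) spans."""
    prefixes = "BIES"
    spans, start, cur = [], None, None
    for i, t in enumerate(tags):
        if t == 0:                               # O: no entity here
            start = cur = None
            continue
        label, pfx = labels[(t - 1) // 4], prefixes[(t - 1) % 4]
        if pfx == "S":                           # single-token entity
            spans.append((label, i, i))
            start = cur = None
        elif pfx == "B":                         # entity start
            start, cur = i, label
        elif pfx == "E" and cur == label and start is not None:
            spans.append((label, start, i))      # entity end: emit the span
            start = cur = None
        elif pfx == "I" and cur == label:
            pass                                 # valid continuation
        else:                                    # invalid continuation; reset
            start = cur = None
    return spans
```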
## Usage (ONNX -- Browser/Edge)
The onnx/ folder contains int8-quantized models for browser inference via ONNX Runtime Web:
```javascript
import * as ort from 'onnxruntime-web';

// Backbone MUST use WASM (fp32) -- WebGPU uses fp16 internally, which
// overflows during RMS norm (pre-norm values ~7000, squared = 49M > fp16 max 65504)
const bb = await ort.InferenceSession.create(bbOnnxUrl, {
  executionProviders: ['wasm'],
  externalData: [{ data: bbDataBuf, path: 'model_quantized.onnx_data' }],
});

// Arctic label encoder can use WebGPU (values stay in safe fp16 range)
const arctic = await ort.InferenceSession.create(arcticOnnxUrl, {
  executionProviders: ['webgpu', 'wasm'],
});

// Load projection weights from projection.npz
// Contains: proj_w (640, 1024), proj_b (640,), o_embed (640,)
const projNpz = await parseNpz(await fetch(projectionNpzUrl).then(r => r.arrayBuffer()));

// Tokenize labels using BIOES templates
const TEMPLATES = {B: 'beginning of {l}', I: 'continuation of {l}', E: 'end of {l}', S: 'complete {l}'};
// For each label, encode 4 template strings through Arctic -> project -> concat with O-embed
// Then: scores = hidden @ embeds.T, decode with Viterbi
```
**Important:** The backbone must run on WASM, not WebGPU. WebGPU's fp16 compute causes overflow in the backbone's RMS normalization layer, corrupting hidden states and making all tokens score as entities. The Arctic encoder is safe for WebGPU since its internal values stay within fp16 range.

**Note:** The ONNX path does not include VtV injection (only dot-product scoring + Viterbi), so results may differ slightly from the full PyTorch pipeline.
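For reference, everything after the two ONNX sessions (projection, O-embed concat, scoring) is plain linear algebra. A NumPy sketch with random stand-ins for the session outputs and the npz weights (key names as documented above):

```python
import numpy as np

# Hypothetical stand-ins for real ONNX session outputs
T, N = 10, 2
hidden = np.random.randn(T, 640).astype(np.float32)         # backbone hidden states
cls_vecs = np.random.randn(N * 4, 1024).astype(np.float32)  # Arctic CLS vectors, 4 per label

# Stand-ins for projection.npz contents: proj_w (640, 1024), proj_b (640,), o_embed (640,)
proj_w = np.random.randn(640, 1024).astype(np.float32)
proj_b = np.zeros(640, dtype=np.float32)
o_embed = np.zeros(640, dtype=np.float32)

projected = cls_vecs @ proj_w.T + proj_b                  # (N*4, 640)
embeds = np.concatenate([o_embed[None, :], projected])    # (1 + N*4, 640)
scores = hidden @ embeds.T                                # (T, 1 + N*4)
```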
## Performance
Apple M4 Pro (CPU, PyTorch, float32): 146 tok/s average, 213ms per inference
| Test | Domain | Tokens | Time | Tok/s |
|---|---|---|---|---|
| C2 infra | Cybersecurity | 30 | 911ms | 32.9 |
| Leaked secrets | Security | 45 | 321ms | 140.2 |
| CVE disclosure | Security | 34 | 270ms | 125.9 |
| Contact info | PII | 28 | 170ms | 164.6 |
| Financial IDs | PII | 29 | 138ms | 209.6 |
| Dinosaurs | Science | 36 | 209ms | 172.3 |
| Basketball | Sports | 24 | 143ms | 168.3 |
| Classical music | Arts | 18 | 115ms | 156.8 |
| Recipe | General | 26 | 130ms | 200.2 |
| Oncology | Medical | 30 | 135ms | 222.6 |
| Space | Science | 32 | 127ms | 251.6 |
| Fantasy writing | Creative | 35 | 121ms | 288.4 |
| Contract | Legal | 36 | 140ms | 257.9 |
| Car specs | Automotive | 45 | 147ms | 306.9 |
| Negative (no entities) | Control | 18 | 113ms | 159.9 |
Entity recall: 36/54 (67%), macro-F1: 0.66 across 15 diverse test cases spanning cybersecurity, PII, science, sports, medical, legal, and creative domains.
## Training Data
1M rows, 252K unique entity types from 9 sources: NuNER v2 (876K rows, 245K types), Pile-NER, AI4Privacy, PIILO, CyberNER, APTNER, AskNews-NER, CRAPII, plus 10K hard negatives.
## License
Apache 2.0