Adaptive NER: Zero-Shot Named Entity Recognition with Dynamic Labels

An adaptive-label NER model built on openai/privacy-filter (1.4B sparse MoE) and Snowflake/snowflake-arctic-embed-l-v2.0 (568M). Accepts any entity label at inference time with no retraining -- type in "dinosaur species", "gene mutation", or "spell name" and it just works.

Trained on 1M rows spanning 252K unique entity types across PII, cybersecurity, biomedical, legal, and general-domain datasets.

How It Works

Standard NER models have a fixed label set baked into their classification head. This model replaces that with dot-product scoring: instead of hidden -> Linear -> logits, it computes hidden . label_embedding -> score. The label embedding comes from a separate encoder that reads the label's natural language description.

4x Label text --> [Arctic-embed-l (568M)] --> (1024,) --> [Proj 1024->640] --+
  "beginning of email"                                                       |-- (4, 640) label matrix
  "continuation of email"                                                    |
  "end of email"                                                             |
  "complete email"                                                           +
                                                                             |
Input text --> [OpenAI NER backbone (1.4B MoE)] --> (T, 640) -- dot product --> Viterbi --> BIOES tags
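In rough PyTorch terms (the names below are illustrative, not from the repo), the swap amounts to replacing a fixed weight matrix with one built at inference time:

# fixed-head NER: the label set is frozen into the classifier weights
logits = hidden @ classifier.weight.T             # (T, num_fixed_labels)

# this model: the "classifier" rows come from encoded label descriptions
label_matrix = embed_label_descriptions(labels)   # (1 + N*4, 640) -- hypothetical helper
scores = hidden @ label_matrix.T                  # (T, 1 + N*4)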

Components

| Component | Model | Params | Role |
|---|---|---|---|
| NER backbone | openai/privacy-filter | 1.4B (sparse MoE, ~50M active) | Per-token hidden states (T, 640) |
| Label encoder | snowflake-arctic-embed-l-v2.0 | 568M (XLM-RoBERTa) | Encodes label descriptions -> (1024,) CLS vectors |
| Projection | nn.Linear(1024, 640) | 655K | Bridges Arctic's 1024-dim to backbone's 640-dim |
| O-embed | nn.Parameter(640) | 640 | Learnable embedding for the "no entity" tag |
| Viterbi decoder | Dynamic transition matrix | -- | Enforces valid BIOES sequences at decode time |

BIOES Tagging

Each label generates 4 embeddings via templated text:

  • B -> "beginning of {label}" -- entity start
  • I -> "continuation of {label}" -- middle tokens
  • E -> "end of {label}" -- entity end
  • S -> "complete {label}" -- single-token entities

The Viterbi decoder enforces valid transitions: B->I, B->E, I->I, I->E, S->O, E->O. Cross-label jumps (B-email -> I-phone) are forbidden.
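A minimal sketch of how such a transition mask could be built (illustrative, assuming the O + 4-tags-per-label index layout used in the usage example below, not necessarily the repo's exact implementation):

import torch

def bioes_transition_mask(num_labels: int) -> torch.Tensor:
    """allowed[i, j] = True iff tag j may follow tag i.
    Index 0 is O; index 1 + 4*k + p is (label k, prefix p) with p = B, I, E, S."""
    B, I, E, S = 0, 1, 2, 3
    n = 1 + 4 * num_labels
    allowed = torch.zeros(n, n, dtype=torch.bool)
    # after O, E-*, or S-* the sequence may emit O or open any entity (B-* / S-*)
    closers = [0] + [1 + 4 * k + p for k in range(num_labels) for p in (E, S)]
    openers = [0] + [1 + 4 * k + p for k in range(num_labels) for p in (B, S)]
    for i in closers:
        for j in openers:
            allowed[i, j] = True
    # inside an entity, transitions must stay within the same label
    for k in range(num_labels):
        b, i_, e = 1 + 4 * k + B, 1 + 4 * k + I, 1 + 4 * k + E
        allowed[b, i_] = allowed[b, e] = True
        allowed[i_, i_] = allowed[i_, e] = True
    return allowed

Forbidden cells are then set to -inf in the Viterbi max-sum recursion, so decoded sequences can never contain an invalid transition.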

VtV Residual Injection

Beyond simple dot-product scoring, the model uses label-conditioned hidden state modification. The label encoder's intermediate representations (from layers 6, 12, 18, 24) are projected and used to nudge the backbone's hidden states:

  1. Arctic CLS vectors from intermediate layers are projected (1024->640) and BIOES-averaged per label -> V (N, 640)
  2. Compute VtV = V^T @ V -> a (640, 640) matrix per injection point
  3. Inject between backbone layers 0, 2, 4, 6:
     hidden = (1 - sigmoid(alpha)) * hidden + sigmoid(alpha) * (hidden @ VtV / N)

Four learnable alpha scalars control injection strength. This adds only ~2.6M parameters on the Arctic side.
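As a sketch of one injection point (the formula above, with V built from the projected, BIOES-averaged CLS vectors and alpha one of the four learned scalars):

import torch

def vtv_inject(hidden, V, alpha):
    """hidden: (T, 640) backbone states; V: (N, 640), one row per label
    (projected Arctic CLS vectors averaged over the four BIOES variants)."""
    N = V.shape[0]
    VtV = V.T @ V                    # (640, 640) label-conditioned operator
    gate = torch.sigmoid(alpha)      # scalar gate in (0, 1)
    return (1 - gate) * hidden + gate * (hidden @ VtV / N)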

Repo Structure

original/                        # PyTorch (full precision, bf16)
  arctic/                        # Fine-tuned Arctic-embed-l encoder
    config.json
    model.safetensors            # 1.1 GB
    tokenizer.json
    tokenizer_config.json
  backbone/                      # Fine-tuned OpenAI NER backbone
    config.json
    model.safetensors            # 2.6 GB
    tokenizer.json
    tokenizer_config.json
  label_encoder_extra.pt         # 13 MB -- projection, O-embed, prefix projections, alphas

onnx/                            # ONNX (quantized, for browser/edge inference)
  arctic/
    encoder_int8.onnx            # 542 MB -- int8 quantized Arctic encoder
    projection.npz               # 2.6 MB -- projection weights + O-embed
    tokenizer.json               # Arctic tokenizer (XLMRobertaTokenizer)
    tokenizer_config.json
  backbone/
    model_quantized_hidden.onnx  # 159 KB -- graph only
    model_quantized.onnx_data    # 1.6 GB -- int8 quantized backbone weights
    tokenizer.json               # Backbone tokenizer
    tokenizer_config.json

Usage (PyTorch)

import torch
import torch.nn as nn
from transformers import AutoModelForTokenClassification, AutoModel, AutoTokenizer

TEMPLATES = {
    "B": "beginning of {label}",
    "I": "continuation of {label}",
    "E": "end of {label}",
    "S": "complete {label}",
}

# Load models (download original/ folder from this repo)
backbone_model = AutoModelForTokenClassification.from_pretrained(
    "path/to/original/backbone", trust_remote_code=True, torch_dtype=torch.float32
)
arctic_model = AutoModel.from_pretrained(
    "path/to/original/arctic", torch_dtype=torch.float32
)
extra = torch.load("path/to/original/label_encoder_extra.pt", map_location="cpu", weights_only=True)

bb_tok = AutoTokenizer.from_pretrained("path/to/original/backbone", trust_remote_code=True)
arctic_tok = AutoTokenizer.from_pretrained("path/to/original/arctic")

# Build projection layer
proj = nn.Linear(1024, 640)
proj.weight.data = extra["proj.weight"]
proj.bias.data = extra["proj.bias"]
o_embed = extra["o_embed"]  # (640,)

# Define your labels -- any text works
labels = ["dinosaur species", "geological period", "continent"]
text = "Tyrannosaurus Rex lived during the late Cretaceous period in North America"

# Encode labels
label_texts = []
for label in labels:
    for pfx in ("B", "I", "E", "S"):
        label_texts.append(TEMPLATES[pfx].format(label=label))

le_enc = arctic_tok(label_texts, padding=True, truncation=True, max_length=32, return_tensors="pt")
with torch.no_grad():
    out = arctic_model(le_enc["input_ids"], le_enc["attention_mask"])
    cls_vecs = out.last_hidden_state[:, 0]  # (N*4, 1024)
    projected = proj(cls_vecs)              # (N*4, 640)
    embeds = torch.cat([o_embed.unsqueeze(0), projected])  # (1 + N*4, 640)

# Encode text and get hidden states
enc = bb_tok(text, add_special_tokens=False, return_tensors="pt")

# Hook into the backbone to grab the pre-classifier hidden states (the input
# to its final Linear head); we score those against the label embeddings.
# A dict capture is used because `nonlocal` is invalid at module scope.
captured = {}
for name, m in backbone_model.named_modules():
    if isinstance(m, nn.Linear) and m.out_features == backbone_model.config.num_labels:
        def hook(module, args):
            captured["hidden"] = args[0]
        m.register_forward_pre_hook(hook)
        break

with torch.no_grad():
    backbone_model(input_ids=enc["input_ids"], attention_mask=torch.ones_like(enc["input_ids"]))

# Score and decode
hidden = captured["hidden"]
scores = hidden[0] @ embeds.T  # (T, 1 + N*4)
tags = scores.argmax(dim=-1).tolist()  # or use Viterbi for valid BIOES sequences
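To turn those indices into readable output: index 0 is O, and each subsequent block of four indices is one label's B/I/E/S variants, in the order the templates were appended above. A small continuation of the example:

PREFIXES = ("B", "I", "E", "S")
tokens = bb_tok.convert_ids_to_tokens(enc["input_ids"][0].tolist())
for tok, t in zip(tokens, tags):
    if t == 0:
        continue  # index 0 is the O ("no entity") embedding
    label_idx, prefix_idx = divmod(t - 1, 4)
    print(tok, f"{PREFIXES[prefix_idx]}-{labels[label_idx]}")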

Usage (ONNX -- Browser/Edge)

The onnx/ folder contains int8-quantized models for browser inference via ONNX Runtime Web:

import * as ort from 'onnxruntime-web';

// Backbone MUST use WASM (fp32) -- WebGPU uses fp16 internally which
// overflows during RMS norm (pre-norm values ~7000, squared = 49M > fp16 max 65504)
const bb = await ort.InferenceSession.create(bbOnnxUrl, {
    executionProviders: ['wasm'],
    externalData: [{ data: bbDataBuf, path: 'model_quantized.onnx_data' }],
});

// Arctic label encoder can use WebGPU (values stay in safe fp16 range)
const arctic = await ort.InferenceSession.create(arcticOnnxUrl, {
    executionProviders: ['webgpu', 'wasm'],
});

// Load projection weights from projection.npz
// Contains: proj_w (640, 1024), proj_b (640,), o_embed (640,)
// (parseNpz is your own helper -- NPZ is just a zip of raw typed arrays)
const projNpz = await parseNpz(await fetch(projectionNpzUrl).then(r => r.arrayBuffer()));

// Tokenize labels using BIOES templates
const TEMPLATES = {B:'beginning of {l}', I:'continuation of {l}', E:'end of {l}', S:'complete {l}'};
// For each label, encode 4 template strings through Arctic -> project -> concat with O-embed
// Then: scores = hidden @ embeds.T, decode with Viterbi

Important: The backbone must run on WASM, not WebGPU. WebGPU's fp16 compute causes overflow in the backbone's RMS normalization layer, corrupting hidden states and making all tokens score as entities. The Arctic encoder is safe for WebGPU since its internal values stay within fp16 range.

Note: The ONNX path does not include VtV injection (only dot-product scoring + Viterbi), so results may differ slightly from the full PyTorch pipeline.
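For sanity-checking outside the browser, the quantized backbone can also be run with the onnxruntime Python package (a sketch: check the graph's actual input/output names via get_inputs()/get_outputs() first, and keep model_quantized.onnx_data next to the graph file so the external weights resolve):

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession(
    "onnx/backbone/model_quantized_hidden.onnx",  # external .onnx_data resolved from same dir
    providers=["CPUExecutionProvider"],
)
print([i.name for i in sess.get_inputs()])  # inspect the expected inputs first

# assuming standard input_ids / attention_mask inputs:
input_ids = np.array([[1, 2, 3]], dtype=np.int64)  # placeholder token ids
outputs = sess.run(None, {
    "input_ids": input_ids,
    "attention_mask": np.ones_like(input_ids),
})
hidden = outputs[0]  # per-token hidden states, e.g. (1, T, 640)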

Performance

Apple M4 Pro (CPU, PyTorch, float32): 146 tok/s average, 213ms per inference

| Test | Domain | Tokens | Time | Tok/s |
|---|---|---|---|---|
| C2 infra | Cybersecurity | 30 | 911ms | 32.9 |
| Leaked secrets | Security | 45 | 321ms | 140.2 |
| CVE disclosure | Security | 34 | 270ms | 125.9 |
| Contact info | PII | 28 | 170ms | 164.6 |
| Financial IDs | PII | 29 | 138ms | 209.6 |
| Dinosaurs | Science | 36 | 209ms | 172.3 |
| Basketball | Sports | 24 | 143ms | 168.3 |
| Classical music | Arts | 18 | 115ms | 156.8 |
| Recipe | General | 26 | 130ms | 200.2 |
| Oncology | Medical | 30 | 135ms | 222.6 |
| Space | Science | 32 | 127ms | 251.6 |
| Fantasy writing | Creative | 35 | 121ms | 288.4 |
| Contract | Legal | 36 | 140ms | 257.9 |
| Car specs | Automotive | 45 | 147ms | 306.9 |
| Negative (no entities) | Control | 18 | 113ms | 159.9 |

Entity recall: 36/54 (67%), macro-F1: 0.66 across 15 diverse test cases spanning cybersecurity, PII, science, sports, medical, legal, and creative domains.

Training Data

1M rows, 252K unique entity types from 9 sources: NuNER v2 (876K rows, 245K types), Pile-NER, AI4Privacy, PIILO, CyberNER, APTNER, AskNews-NER, CRAPII, plus 10K hard negatives.

License

Apache 2.0
