Manikrishneshwar Sasidhar committed
Commit bfcecff · verified · 1 Parent(s): 693eef2

Initial upload: BERT+GAT PII redactor

README.md ADDED
@@ -0,0 +1,102 @@
---
language:
- en
license: mit
library_name: transformers
tags:
- pii
- privacy
- redaction
- token-classification
- ner
- bert
- gat
- graph-attention-network
pipeline_tag: token-classification
---

# PII Redactor — BERT + Graph Attention Network

Token-level PII detection model that combines a BERT contextual encoder
with a Graph Attention Network (GAT) refinement stage. The graph mixes
sequential-window edges with top-k attention edges drawn from BERT's last
layer, letting the GAT exploit both locality and the long-range
dependencies BERT already discovered.

The model emits BIO tags over 15 PII categories: `SSN`, `BANK_ACCOUNT`,
`ROUTING_NUMBER`, `CREDIT_CARD`, `CVV`, `CARD_EXPIRY`, `IBAN`, `DOB`,
`FULL_NAME`, `EMAIL`, `PHONE`, `ADDRESS`, `PASSPORT`, `DRIVERS_LICENSE`,
`TAX_ID`.

## Quick start

```python
from transformers import AutoModel, AutoTokenizer

REPO = "your-username/pii-redactor-bert-gat"  # <-- replace

tokenizer = AutoTokenizer.from_pretrained(REPO, trust_remote_code=True)
model = AutoModel.from_pretrained(REPO, trust_remote_code=True)
model.eval()

result = model.predict(
    "Email me at john.doe@example.com or call 555-123-4567.",
    tokenizer,
)
print(result["redacted"])
# -> "Email me at [EMAIL] or call [PHONE]."
print(result["spans"])
# -> [{'start': 12, 'end': 32, 'label': 'EMAIL', 'value': 'john.doe@example.com'}, ...]
```

`trust_remote_code=True` is required because the architecture (BERT + GAT)
is custom and ships as `modeling_bert_gat.py` in this repository.

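The `predict` helper processes one string at a time; for larger jobs you can move the model to a GPU and loop over texts. A minimal sketch (the input list and device check below are illustrative, not part of the repo):

```python
import torch

texts = [
    "My SSN is 123-45-6789.",
    "Card 4111 1111 1111 1111, exp 04/27.",
]  # hypothetical inputs

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

for t in texts:
    out = model.predict(t, tokenizer, device=torch.device(device))
    print(out["redacted"])
```
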
## Architecture

```
input_ids ──► BERT encoder (with output_attentions=True)
          │
          ▼
token embeddings + last-layer attention
          │
          ▼
build_token_graph(window=3, top_k=5)
          │
          ▼
stack of GATConv layers (heads=4, hidden=128)
          │
          ▼
residual + LayerNorm ──► classifier ──► BIO logits
```

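The graph stage is the part that differs from a plain BERT tagger. The sketch below is an illustrative, self-contained re-statement of the same idea — sequential edges within ±window plus each token's head-averaged top-k attention targets — run on a toy attention tensor; the authoritative version is `_build_token_graph` in `modeling_bert_gat.py`.

```python
import torch

def toy_token_graph(attn: torch.Tensor, window: int = 3, top_k: int = 5):
    """attn: (heads, seq_len, seq_len) last-layer attention for one sequence."""
    seq_len = attn.shape[-1]
    src, dst, wt = [], [], []

    # 1) Sequential edges: each token connects to its neighbours within ±window.
    for i in range(seq_len):
        for j in range(max(0, i - window), min(seq_len, i + window + 1)):
            if i != j:
                src.append(i)
                dst.append(j)
                wt.append(1.0)

    # 2) Attention edges: head-averaged top-k targets per token.
    avg = attn.mean(dim=0)                             # (seq_len, seq_len)
    vals, idx = avg.topk(min(top_k, seq_len), dim=-1)
    for i in range(seq_len):
        for j, w in zip(idx[i].tolist(), vals[i].tolist()):
            if i != j and w > 1e-4:
                src.append(i)
                dst.append(j)
                wt.append(w)

    edge_index = torch.tensor([src, dst], dtype=torch.long)       # (2, num_edges)
    edge_attr = torch.tensor(wt, dtype=torch.float).unsqueeze(1)  # (num_edges, 1)
    return edge_index, edge_attr

# Toy run: 6 tokens, 4 heads of random "attention" weights.
ei, ea = toy_token_graph(torch.rand(4, 6, 6), window=3, top_k=5)
print(ei.shape, ea.shape)
```
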
## Inputs / outputs

* **Input:** raw text string.
* **Output:** dict with `original`, `redacted`, and `spans` (list of
  `{start, end, label, value}`).

## Intended use

* Pre-processing user-generated text before logging or storing.
* Building privacy-preserving data pipelines.
* Demonstrating BERT + graph-network hybrids for NER.

## Limitations

* Trained on synthetic English PII; real-world distributions may differ.
* Latency is higher than vanilla BERT-NER because the graph is built and
  the GAT runs per sample.
* Coverage is limited to the 15 categories above.

## Requirements

```text
torch>=2.0
transformers>=4.30
torch-geometric>=2.3
```

## License

MIT.
config.json ADDED
@@ -0,0 +1,86 @@
{
  "architectures": [
    "BertGATForTokenClassification"
  ],
  "auto_map": {
    "AutoConfig": "configuration_bert_gat.BertGATConfig",
    "AutoModel": "modeling_bert_gat.BertGATForTokenClassification"
  },
  "bert_model_name": "distilbert-base-uncased",
  "dropout": 0.0,
  "dtype": "float32",
  "gat_heads": 4,
  "gat_hidden": 128,
  "gat_layers": 2,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2",
    "3": "LABEL_3",
    "4": "LABEL_4",
    "5": "LABEL_5",
    "6": "LABEL_6",
    "7": "LABEL_7",
    "8": "LABEL_8",
    "9": "LABEL_9",
    "10": "LABEL_10",
    "11": "LABEL_11",
    "12": "LABEL_12",
    "13": "LABEL_13",
    "14": "LABEL_14",
    "15": "LABEL_15",
    "16": "LABEL_16",
    "17": "LABEL_17",
    "18": "LABEL_18",
    "19": "LABEL_19",
    "20": "LABEL_20",
    "21": "LABEL_21",
    "22": "LABEL_22",
    "23": "LABEL_23",
    "24": "LABEL_24",
    "25": "LABEL_25",
    "26": "LABEL_26",
    "27": "LABEL_27",
    "28": "LABEL_28",
    "29": "LABEL_29",
    "30": "LABEL_30"
  },
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_10": 10,
    "LABEL_11": 11,
    "LABEL_12": 12,
    "LABEL_13": 13,
    "LABEL_14": 14,
    "LABEL_15": 15,
    "LABEL_16": 16,
    "LABEL_17": 17,
    "LABEL_18": 18,
    "LABEL_19": 19,
    "LABEL_2": 2,
    "LABEL_20": 20,
    "LABEL_21": 21,
    "LABEL_22": 22,
    "LABEL_23": 23,
    "LABEL_24": 24,
    "LABEL_25": 25,
    "LABEL_26": 26,
    "LABEL_27": 27,
    "LABEL_28": 28,
    "LABEL_29": 29,
    "LABEL_3": 3,
    "LABEL_30": 30,
    "LABEL_4": 4,
    "LABEL_5": 5,
    "LABEL_6": 6,
    "LABEL_7": 7,
    "LABEL_8": 8,
    "LABEL_9": 9
  },
  "max_length": 256,
  "model_type": "bert_gat_pii",
  "top_k_attn": 5,
  "transformers_version": "5.1.0",
  "window": 3
}
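Note that `id2label` / `label2id` are left at transformers' generic `LABEL_N` defaults; the human-readable BIO tags live in `ID2LABEL` inside `modeling_bert_gat.py`. A quick sketch to print the index-to-tag mapping, assuming the label ordering defined in that file:

```python
# Reconstructed from the LABELS list in modeling_bert_gat.py (assumption:
# the classifier's output indices follow that ordering).
PII_TYPES = [
    "SSN", "BANK_ACCOUNT", "ROUTING_NUMBER", "CREDIT_CARD", "CVV",
    "CARD_EXPIRY", "IBAN", "DOB", "FULL_NAME", "EMAIL", "PHONE",
    "ADDRESS", "PASSPORT", "DRIVERS_LICENSE", "TAX_ID",
]
LABELS = ["O"] + [tag for t in PII_TYPES for tag in (f"B-{t}", f"I-{t}")]

for idx, tag in enumerate(LABELS):
    print(f"LABEL_{idx} -> {tag}")  # e.g. LABEL_0 -> O, LABEL_1 -> B-SSN, ...
```
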
configuration_bert_gat.py ADDED
@@ -0,0 +1,36 @@
"""
HuggingFace-compatible config for the BERT+GAT PII redactor.

When the model repo is loaded with ``trust_remote_code=True``,
``transformers`` will instantiate this class from ``config.json``.
"""

from transformers import PretrainedConfig


class BertGATConfig(PretrainedConfig):
    model_type = "bert_gat_pii"

    def __init__(
        self,
        bert_model_name: str = "distilbert-base-uncased",
        num_labels: int = 31,
        gat_heads: int = 4,
        gat_hidden: int = 128,
        gat_layers: int = 2,
        dropout: float = 0.1,
        window: int = 3,
        top_k_attn: int = 5,
        max_length: int = 512,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.bert_model_name = bert_model_name
        self.num_labels = num_labels
        self.gat_heads = gat_heads
        self.gat_hidden = gat_hidden
        self.gat_layers = gat_layers
        self.dropout = dropout
        self.window = window
        self.top_k_attn = top_k_attn
        self.max_length = max_length
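For local experiments the config can be built and round-tripped without touching the Hub. A minimal sketch using the standard `PretrainedConfig` save/load API (the output directory name is illustrative):

```python
from configuration_bert_gat import BertGATConfig

# Non-default hyperparameters, just to show the knobs exposed above.
cfg = BertGATConfig(gat_layers=3, window=2, top_k_attn=8)
cfg.save_pretrained("./bert-gat-config-demo")        # writes config.json
reloaded = BertGATConfig.from_pretrained("./bert-gat-config-demo")
print(reloaded.gat_layers, reloaded.window, reloaded.top_k_attn)  # 3 2 8
```
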
model.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8c8cf6d12b41debd9a1a5d1b92360b35affd6fdae539109c39539e6db608d70e
size 269749308
modeling_bert_gat.py ADDED
@@ -0,0 +1,230 @@
"""
HuggingFace-compatible wrapper around BertGATPIIModel.

Self-contained on purpose: the Hub repo doesn't import ``pii_redactor``,
so we redeclare the architecture here. This is the file ``transformers``
loads when a user does::

    from transformers import AutoModel
    model = AutoModel.from_pretrained(
        "your-username/pii-redactor-bert-gat", trust_remote_code=True
    )
"""

from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoConfig, AutoModel, AutoTokenizer, PreTrainedModel

# Dual-mode import: works both when this file is loaded as part of a
# package (HuggingFace's ``trust_remote_code=True`` flow) and when it's
# imported as a sibling module by a script like ``convert_checkpoint.py``.
try:
    from .configuration_bert_gat import BertGATConfig
except ImportError:
    from configuration_bert_gat import BertGATConfig

try:
    from torch_geometric.nn import GATConv
except ImportError as e:  # pragma: no cover
    raise ImportError(
        "torch-geometric is required. Install with: pip install torch-geometric"
    ) from e


# --------------------------------------------------------------------------- #
# Label space (kept in sync with pii_redactor.config)
# --------------------------------------------------------------------------- #
PII_TYPES = [
    "SSN", "BANK_ACCOUNT", "ROUTING_NUMBER", "CREDIT_CARD", "CVV",
    "CARD_EXPIRY", "IBAN", "DOB", "FULL_NAME", "EMAIL", "PHONE",
    "ADDRESS", "PASSPORT", "DRIVERS_LICENSE", "TAX_ID",
]
LABELS = ["O"] + sum(([f"B-{t}", f"I-{t}"] for t in PII_TYPES), [])
ID2LABEL = {i: l for i, l in enumerate(LABELS)}


# --------------------------------------------------------------------------- #
# Graph builder (mirrors pii_redactor.models.graph_builder)
# --------------------------------------------------------------------------- #
def _build_token_graph(
    seq_len: int,
    attn_weights: torch.Tensor,
    window: int,
    top_k: int,
    device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
    src_list, dst_list, wt_list = [], [], []

    for i in range(seq_len):
        for j in range(max(0, i - window), min(seq_len, i + window + 1)):
            if i != j:
                src_list.append(i)
                dst_list.append(j)
                wt_list.append(1.0)

    avg_attn = attn_weights.mean(dim=0)
    topk_vals, topk_idx = avg_attn.topk(min(top_k, seq_len), dim=-1)
    for i in range(seq_len):
        for ki in range(topk_idx.shape[1]):
            j = topk_idx[i, ki].item()
            wt = topk_vals[i, ki].item()
            if i != j and wt > 1e-4:
                src_list.append(i)
                dst_list.append(j)
                wt_list.append(wt)

    edge_index = torch.tensor([src_list, dst_list], dtype=torch.long, device=device)
    edge_attr = torch.tensor(wt_list, dtype=torch.float, device=device).unsqueeze(1)
    return edge_index, edge_attr


# --------------------------------------------------------------------------- #
# Model
# --------------------------------------------------------------------------- #
class BertGATForTokenClassification(PreTrainedModel):
    config_class = BertGATConfig
    base_model_prefix = "bert_gat_pii"

    # This model has no tied weights (no shared embeddings, no encoder-
    # decoder). Different transformers versions look for either the
    # old ``_tied_weights_keys`` (list) or the newer
    # ``all_tied_weights_keys`` (dict); declaring both empty keeps
    # ``from_pretrained``'s post-load tied-weight bookkeeping happy.
    _tied_weights_keys: list = []
    all_tied_weights_keys: dict = {}

    def __init__(self, config: BertGATConfig):
        super().__init__(config)

        # Instantiate the BERT trunk EMPTY (no weight download here). The
        # outer ``from_pretrained`` populates everything — including these
        # parameters — from the saved state dict. Calling
        # ``AutoModel.from_pretrained`` here would clash with the meta-
        # device context the outer loader sets up.
        bert_config = AutoConfig.from_pretrained(config.bert_model_name)
        bert_config.output_attentions = True
        self.bert = AutoModel.from_config(bert_config)
        bert_dim = self.bert.config.hidden_size

        self.dropout = nn.Dropout(config.dropout)
        self.window = config.window
        self.top_k = config.top_k_attn

        self.gat_layers = nn.ModuleList()
        in_dim = bert_dim
        for _ in range(config.gat_layers):
            self.gat_layers.append(
                GATConv(in_dim, config.gat_hidden, heads=config.gat_heads,
                        concat=True, dropout=config.dropout, edge_dim=1)
            )
            in_dim = config.gat_hidden * config.gat_heads

        self.layer_norm = nn.LayerNorm(in_dim)
        self.residual_proj = nn.Linear(bert_dim, in_dim)
        self.classifier = nn.Linear(in_dim, config.num_labels)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
    ):
        B, L = input_ids.shape
        device = input_ids.device

        bert_out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        token_embs = bert_out.last_hidden_state
        last_attn = bert_out.attentions[-1]

        gat_outputs = []
        for b in range(B):
            seq_len = int(attention_mask[b].sum().item())
            attn_b = last_attn[b, :, :seq_len, :seq_len]
            edge_idx, edge_attr = _build_token_graph(
                seq_len, attn_b, self.window, self.top_k, device,
            )

            h_real = token_embs[b, :seq_len]
            h_res = h_real
            for gat in self.gat_layers:
                h_real = self.dropout(h_real)
                h_real = gat(h_real, edge_idx, edge_attr=edge_attr)
                h_real = F.elu(h_real)
            h_real = self.layer_norm(h_real + self.residual_proj(h_res))

            pad_len = L - seq_len
            if pad_len > 0:
                pad = torch.zeros(pad_len, h_real.shape[-1], device=device)
                h_real = torch.cat([h_real, pad], dim=0)
            gat_outputs.append(h_real)

        gat_embs = torch.stack(gat_outputs, dim=0)
        logits = self.classifier(self.dropout(gat_embs))

        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss(ignore_index=-100)(
                logits.view(-1, self.config.num_labels), labels.view(-1)
            )
        return {"loss": loss, "logits": logits}

    # ---- Convenience inference helpers -------------------------------------
    @torch.no_grad()
    def predict(
        self,
        text: str,
        tokenizer: AutoTokenizer,
        device: Optional[torch.device] = None,
    ) -> dict:
        device = device or next(self.parameters()).device
        enc = tokenizer(
            text,
            return_tensors="pt",
            return_offsets_mapping=True,
            truncation=True,
            max_length=self.config.max_length,
        )
        input_ids = enc["input_ids"].to(device)
        attention_mask = enc["attention_mask"].to(device)
        offsets = enc["offset_mapping"].squeeze(0).tolist()

        out = self(input_ids, attention_mask)
        preds = out["logits"].squeeze(0).argmax(dim=-1).cpu().tolist()

        # preds and offsets are aligned 1:1 by index; iterate them
        # together (zip-style) so that special tokens — whose offset is
        # (0, 0) — and their matching prediction are skipped as a pair.
        spans: List[dict] = []
        cur_lbl, cur_start, cur_end = None, None, None
        for pred_id, (tok_s, tok_e) in zip(preds, offsets):
            if tok_s == tok_e:
                continue
            pred_lbl = ID2LABEL[pred_id]
            if pred_lbl.startswith("B-"):
                if cur_lbl:
                    spans.append({"start": cur_start, "end": cur_end, "label": cur_lbl})
                cur_lbl, cur_start, cur_end = pred_lbl[2:], tok_s, tok_e
            elif pred_lbl.startswith("I-") and cur_lbl == pred_lbl[2:]:
                cur_end = tok_e
            else:
                if cur_lbl:
                    spans.append({"start": cur_start, "end": cur_end, "label": cur_lbl})
                cur_lbl, cur_start, cur_end = None, None, None
        if cur_lbl:
            spans.append({"start": cur_start, "end": cur_end, "label": cur_lbl})

        for sp in spans:
            sp["value"] = text[sp["start"]:sp["end"]]

        redacted = text
        for sp in sorted(spans, key=lambda s: s["start"], reverse=True):
            redacted = redacted[:sp["start"]] + f"[{sp['label']}]" + redacted[sp["end"]:]
        return {"original": text, "redacted": redacted, "spans": spans}


# Hooks so AutoModel can find this class via the config's auto_map.
BertGATConfig.register_for_auto_class("AutoConfig")
BertGATForTokenClassification.register_for_auto_class("AutoModel")
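Because `forward()` accepts `labels` and returns a loss, the uploaded weights can be fine-tuned further. A minimal, hedged single-step sketch — the labels here are placeholders (tag id 0, i.e. `O`, everywhere, with `-100` on special tokens so the loss ignores them); real training needs a proper BIO-labelled dataset:

```python
import torch
from transformers import AutoModel, AutoTokenizer

REPO = "your-username/pii-redactor-bert-gat"  # <-- replace, as in the README
tokenizer = AutoTokenizer.from_pretrained(REPO, trust_remote_code=True)
model = AutoModel.from_pretrained(REPO, trust_remote_code=True)
model.train()

enc = tokenizer(["My SSN is 123-45-6789."], return_tensors="pt", padding=True)
labels = torch.zeros_like(enc["input_ids"])  # placeholder "O" labels
labels[0, 0] = -100    # [CLS]
labels[0, -1] = -100   # [SEP]

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
out = model(enc["input_ids"], enc["attention_mask"], labels=labels)
out["loss"].backward()
optimizer.step()
print(float(out["loss"]))
```
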
tokenizer.json ADDED
The diff for this file is too large to render.
 
tokenizer_config.json ADDED
@@ -0,0 +1,14 @@
{
  "backend": "tokenizers",
  "cls_token": "[CLS]",
  "do_lower_case": true,
  "is_local": true,
  "mask_token": "[MASK]",
  "model_max_length": 512,
  "pad_token": "[PAD]",
  "sep_token": "[SEP]",
  "strip_accents": null,
  "tokenize_chinese_chars": true,
  "tokenizer_class": "BertTokenizer",
  "unk_token": "[UNK]"
}