hadung1802 commited on
Commit
2c5b728
·
1 Parent(s): 8db6f43

Upload ViSoNorm trained model

Browse files
README.md ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ViSoNorm: Vietnamese Text Normalization Model
2
+
3
+ ViSoNorm is a state-of-the-art Vietnamese text normalization model that converts informal, non-standard Vietnamese text into standard Vietnamese. The model uses a multi-task learning approach with NSW (Non-Standard Word) detection, mask prediction, and lexical normalization heads.
4
+
5
+ ## Model Architecture
6
+
7
+ - **Base Model**: ViSoBERT (Vietnamese Social Media BERT)
8
+ - **Multi-task Heads**:
9
+ - NSW Detection: Identifies tokens that need normalization
10
+ - Mask Prediction: Determines how many masks to add for multi-token expansions
11
+ - Lexical Normalization: Predicts normalized tokens
12
+
13
+ ## Features
14
+
15
+ - **Self-contained inference**: Built-in `normalize_text` method
16
+ - **NSW detection**: Built-in `detect_nsw` method for detailed analysis
17
+ - **HuggingFace compatible**: Works seamlessly with `AutoModelForMaskedLM`
18
+ - **Production ready**: No hardcoded patterns, works for any Vietnamese text
19
+ - **Multi-token expansion**: Handles cases like "sv" → "sinh viên", "ctrai" → "con trai"
20
+ - **Confidence scoring**: Provides confidence scores for NSW detection and normalization
21
+
22
+ ## Installation
23
+
24
+ ```bash
25
+ pip install transformers torch
26
+ ```
27
+
28
+ ## Usage
29
+
30
+ ### Basic Usage
31
+
32
+ ```python
33
+ from transformers import AutoTokenizer, AutoModelForMaskedLM
34
+
35
+ # Load model and tokenizer
36
+ model_repo = "hadung1802/visobert-normalizer"
37
+ tokenizer = AutoTokenizer.from_pretrained(model_repo)
38
+ model = AutoModelForMaskedLM.from_pretrained(model_repo, trust_remote_code=True)
39
+
40
+ # Normalize text
41
+ text = "sv dh gia dinh chua cho di lam :))"
42
+ normalized_text, source_tokens, predicted_tokens = model.normalize_text(
43
+ tokenizer, text, device='cpu'
44
+ )
45
+
46
+ print(f"Original: {text}")
47
+ print(f"Normalized: {normalized_text}")
48
+ ```
49
+
50
+ ### NSW Detection
51
+
52
+ ```python
53
+ # Detect Non-Standard Words (NSW) in text
54
+ text = "nhìn thôi cung thấy đau long quá đi :))"
55
+ nsw_results = model.detect_nsw(tokenizer, text, device='cpu')
56
+
57
+ print(f"Text: {text}")
58
+ for result in nsw_results:
59
+ print(f"NSW: '{result['nsw']}' → '{result['prediction']}' (confidence: {result['confidence_score']})")
60
+ ```
61
+
62
+ ### Batch Processing
63
+
64
+ ```python
65
+ texts = [
66
+ "sv dh gia dinh chua cho di lam :))",
67
+ "chúng nó bảo em là ctrai",
68
+ "t vs b chơi vs nhau đã lâu"
69
+ ]
70
+
71
+ for text in texts:
72
+ normalized_text, _, _ = model.normalize_text(tokenizer, text, device='cpu')
73
+ print(f"{text} → {normalized_text}")
74
+ ```
75
+
76
+ ### Expected Output
77
+
78
+ #### Text Normalization
79
+ ```
80
+ sv dh gia dinh chua cho di lam :)) → sinh viên đại học gia đình chưa cho đi làm :))
81
+ chúng nó bảo em là ctrai → chúng nó bảo em là con trai
82
+ t vs b chơi vs nhau đã lâu → tôi với bạn chơi với nhau đã lâu
83
+ ```
84
+
85
+ #### NSW Detection
86
+ ```python
87
+ # Input: "nhìn thôi cung thấy đau long quá đi :))"
88
+ [
89
+ {
90
+ "index": 3,
91
+ "start_index": 10,
92
+ "end_index": 14,
93
+ "nsw": "cung",
94
+ "prediction": "cũng",
95
+ "confidence_score": 0.9415
96
+ },
97
+ {
98
+ "index": 6,
99
+ "start_index": 24,
100
+ "end_index": 28,
101
+ "nsw": "long",
102
+ "prediction": "lòng",
103
+ "confidence_score": 0.7056
104
+ }
105
+ ]
106
+ ```
107
+
108
+ ### NSW Detection Output Format
109
+
110
+ The `detect_nsw` method returns a list of dictionaries with the following structure:
111
+
112
+ - **`index`**: Position of the token in the sequence
113
+ - **`start_index`**: Start character position in the original text
114
+ - **`end_index`**: End character position in the original text
115
+ - **`nsw`**: The original non-standard word (detokenized)
116
+ - **`prediction`**: The predicted normalized word (detokenized)
117
+ - **`confidence_score`**: Combined confidence score (0.0 to 1.0)
added_tokens.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "<space>": 15002
3
+ }
config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "ViSoNormViSoBERTForMaskedLM"
4
+ ],
5
+ "model_type": "xlm-roberta",
6
+ "vocab_size": 15003,
7
+ "pad_token_id": 1,
8
+ "bos_token_id": 0,
9
+ "eos_token_id": 2,
10
+ "mask_token_id": 3,
11
+ "mask_n_predictor": true,
12
+ "nsw_detector": true,
13
+ "auto_map": {
14
+ "AutoModel": "visonorm_visobert_model.ViSoNormViSoBERTForMaskedLM",
15
+ "AutoModelForMaskedLM": "visonorm_visobert_model.ViSoNormViSoBERTForMaskedLM"
16
+ },
17
+ "hidden_size": 768,
18
+ "num_attention_heads": 12,
19
+ "num_hidden_layers": 12,
20
+ "intermediate_size": 3072,
21
+ "max_position_embeddings": 514,
22
+ "type_vocab_size": 2,
23
+ "initializer_range": 0.02,
24
+ "layer_norm_eps": 1e-12
25
+ }
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4c56f8433e6e26b0d6411414956a01eee2acc0c0471173b6e74e74a352b84993
3
+ size 393240883
sentencepiece.bpe.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:02aaf05cda4db99e86b7c76eba6258867ce4d043da0fed19c87a7d46c8b53a65
3
+ size 470732
special_tokens_map.json ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "<s>",
3
+ "cls_token": "<s>",
4
+ "eos_token": "</s>",
5
+ "mask_token": {
6
+ "content": "<mask>",
7
+ "lstrip": true,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false
11
+ },
12
+ "pad_token": "<pad>",
13
+ "sep_token": "</s>",
14
+ "unk_token": "<unk>"
15
+ }
state_dict_report.json ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "base_model": "visobert",
3
+ "total_params": 213,
4
+ "expected_heads_present": {
5
+ "cls_decoder.weight": false,
6
+ "cls_decoder.bias": false,
7
+ "cls_dense.weight": false,
8
+ "cls_dense.bias": false,
9
+ "cls_layer_norm.weight": false,
10
+ "cls_layer_norm.bias": false,
11
+ "mask_n_predictor.mask_predictor_dense.weight": true,
12
+ "mask_n_predictor.mask_predictor_dense.bias": true,
13
+ "mask_n_predictor.mask_predictor_proj.weight": true,
14
+ "mask_n_predictor.mask_predictor_proj.bias": true,
15
+ "nsw_detector.dense.weight": true,
16
+ "nsw_detector.dense.bias": true,
17
+ "nsw_detector.predictor.weight": true,
18
+ "nsw_detector.predictor.bias": true
19
+ },
20
+ "alt_common_heads_present": {
21
+ "lm_head.weight": false,
22
+ "lm_head.bias": false,
23
+ "cls.decoder.weight": true,
24
+ "cls.decoder.bias": true,
25
+ "cls.dense.weight": true,
26
+ "cls.dense.bias": true,
27
+ "cls.layer_norm.weight": true,
28
+ "cls.layer_norm.bias": true
29
+ },
30
+ "aux_heads_present": {
31
+ "nsw_detector.": true,
32
+ "mask_n_predictor.": true
33
+ },
34
+ "example_keys": [
35
+ "roberta.embeddings.word_embeddings.weight",
36
+ "roberta.embeddings.position_embeddings.weight",
37
+ "roberta.embeddings.token_type_embeddings.weight",
38
+ "roberta.embeddings.LayerNorm.weight",
39
+ "roberta.embeddings.LayerNorm.bias",
40
+ "roberta.encoder.layer.0.attention.self.query.weight",
41
+ "roberta.encoder.layer.0.attention.self.query.bias",
42
+ "roberta.encoder.layer.0.attention.self.key.weight",
43
+ "roberta.encoder.layer.0.attention.self.key.bias",
44
+ "roberta.encoder.layer.0.attention.self.value.weight",
45
+ "roberta.encoder.layer.0.attention.self.value.bias",
46
+ "roberta.encoder.layer.0.attention.output.dense.weight",
47
+ "roberta.encoder.layer.0.attention.output.dense.bias",
48
+ "roberta.encoder.layer.0.attention.output.LayerNorm.weight",
49
+ "roberta.encoder.layer.0.attention.output.LayerNorm.bias",
50
+ "roberta.encoder.layer.0.intermediate.dense.weight",
51
+ "roberta.encoder.layer.0.intermediate.dense.bias",
52
+ "roberta.encoder.layer.0.output.dense.weight",
53
+ "roberta.encoder.layer.0.output.dense.bias",
54
+ "roberta.encoder.layer.0.output.LayerNorm.weight"
55
+ ]
56
+ }
tokenizer_config.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "tokenizer_class": "XLMRobertaTokenizer",
3
+ "model_max_length": 512,
4
+ "padding_side": "right",
5
+ "truncation_side": "right",
6
+ "pad_token": "<pad>",
7
+ "bos_token": "<s>",
8
+ "eos_token": "</s>",
9
+ "unk_token": "<unk>",
10
+ "mask_token": "<mask>",
11
+ "additional_special_tokens": [
12
+ "<pad>",
13
+ "<s>",
14
+ "</s>",
15
+ "<unk>",
16
+ "<mask>"
17
+ ],
18
+ "use_fast": false
19
+ }
training_args.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "base_model": "visobert",
3
+ "training_mode": "weakly_supervised",
4
+ "learning_rate": 0.001,
5
+ "num_epochs": 10,
6
+ "train_batch_size": 16,
7
+ "eval_batch_size": 128,
8
+ "remove_accents": false,
9
+ "lower_case": false
10
+ }
visonorm_visobert_model.py ADDED
@@ -0,0 +1,557 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Custom ViSoNorm model class for ViSoBERT-based models.
4
+ This preserves the custom heads needed for text normalization and
5
+ is loadable via auto_map without custom model_type.
6
+ """
7
+
8
+ import math
9
+ import torch
10
+ import torch.nn as nn
11
+ from transformers import XLMRobertaModel, XLMRobertaConfig, XLMRobertaPreTrainedModel
12
+ from transformers.modeling_outputs import MaskedLMOutput
13
+ # Define constants locally to avoid external dependencies
14
+ NUM_LABELS_N_MASKS = 5
15
+
16
+
17
def gelu(x):
    """Exact (erf-based) GELU activation, as used in the original BERT."""
    return 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))
19
+
20
+
21
class XLMRobertaLMHead(nn.Module):
    """Masked-LM head whose decoder weights are tied to the input embeddings.

    All layer sizes are derived from the embedding matrix itself (not from
    ``config``) so the head always lines up with the loaded checkpoint.
    """

    def __init__(self, config, xlmroberta_model_embedding_weights):
        super().__init__()
        # shape: (vocab_size, hidden_dim) — take sizes from the real weights.
        num_labels, actual_hidden_size = (
            xlmroberta_model_embedding_weights.size(0),
            xlmroberta_model_embedding_weights.size(1),
        )
        self.dense = nn.Linear(actual_hidden_size, actual_hidden_size)
        self.layer_norm = nn.LayerNorm(actual_hidden_size, eps=1e-12)

        self.decoder = nn.Linear(actual_hidden_size, num_labels, bias=False)
        # Weight tying: decoder shares the embedding parameter object.
        self.decoder.weight = xlmroberta_model_embedding_weights
        self.decoder.bias = nn.Parameter(torch.zeros(num_labels))

    def forward(self, features):
        hidden = self.layer_norm(gelu(self.dense(features)))
        return self.decoder(hidden)
40
+
41
+
42
class XLMRobertaMaskNPredictionHead(nn.Module):
    """Head predicting how many ``<mask>`` tokens to insert at each position.

    Emits one logit per class in ``NUM_LABELS_N_MASKS`` for every token of
    the encoder output.
    """

    def __init__(self, config, actual_hidden_size):
        super().__init__()
        self.mask_predictor_dense = nn.Linear(actual_hidden_size, 50)
        self.mask_predictor_proj = nn.Linear(50, NUM_LABELS_N_MASKS)
        self.activation = gelu

    def forward(self, sequence_output):
        hidden = self.activation(self.mask_predictor_dense(sequence_output))
        return self.mask_predictor_proj(hidden)
53
+
54
+
55
class XLMRobertaBinaryPredictor(nn.Module):
    """Two-class token classifier used as the NSW (non-standard word) detector."""

    def __init__(self, hidden_size, dense_dim=100):
        super().__init__()
        self.dense = nn.Linear(hidden_size, dense_dim)
        # Named 'predictor' to match the checkpoint parameter names.
        self.predictor = nn.Linear(dense_dim, 2)
        self.activation = gelu

    def forward(self, sequence_output):
        return self.predictor(self.activation(self.dense(sequence_output)))
67
+
68
+
69
class ViSoNormViSoBERTForMaskedLM(XLMRobertaPreTrainedModel):
    """ViSoBERT-based masked-LM model with ViSoNorm's auxiliary heads.

    On top of the XLM-RoBERTa encoder, three heads run over every token
    position:

    * ``cls``              -- lexical normalization (masked-LM) logits
    * ``mask_n_predictor`` -- how many ``<mask>`` tokens to insert
    * ``nsw_detector``     -- binary non-standard-word (NSW) classifier

    The model also ships self-contained inference helpers
    (``normalize_text``, ``detect_nsw``) so it can be used directly after
    ``AutoModelForMaskedLM.from_pretrained(..., trust_remote_code=True)``.
    """

    config_class = XLMRobertaConfig

    class _ViSoNormOutput:
        """Lightweight output container for the ``return_dict`` path of forward.

        ``logits`` aliases ``logits_norm`` so generic MaskedLM consumers keep
        working.  (Hoisted out of ``forward`` — the original re-created this
        class on every call.)
        """

        def __init__(self, logits_norm, logits_n_masks_pred, logits_nsw_detection,
                     hidden_states=None, attentions=None):
            self.logits = logits_norm
            self.logits_norm = logits_norm
            self.logits_n_masks_pred = logits_n_masks_pred
            self.logits_nsw_detection = logits_nsw_detection
            self.hidden_states = hidden_states
            self.attentions = attentions

    def __init__(self, config: XLMRobertaConfig):
        super().__init__(config)
        self.roberta = XLMRobertaModel(config)

        # Derive the hidden size from the actual embedding matrix so the
        # heads always match the loaded checkpoint, even if config drifts.
        actual_hidden_size = self.roberta.embeddings.word_embeddings.weight.size(1)

        # ViSoNorm normalization head - use exact same structure as training
        # (decoder tied to the input embeddings).
        self.cls = XLMRobertaLMHead(config, self.roberta.embeddings.word_embeddings.weight)

        # Additional heads for ViSoNorm functionality.
        self.mask_n_predictor = XLMRobertaMaskNPredictionHead(config, actual_hidden_size)
        self.nsw_detector = XLMRobertaBinaryPredictor(actual_hidden_size, dense_dim=100)
        self.num_labels_n_mask = NUM_LABELS_N_MASKS

        # Initialize per HF conventions.
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Run the encoder and all three ViSoNorm heads.

        ``labels`` is accepted for interface compatibility; no loss is
        computed here.  Returns a tuple ``(logits_norm, logits_n_masks_pred,
        logits_nsw_detection, *encoder_extras)`` when ``return_dict`` is
        falsy, otherwise a ``_ViSoNormOutput``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.roberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        # All three prediction heads share the final hidden states.
        logits_norm = self.cls(sequence_output)
        logits_n_masks_pred = self.mask_n_predictor(sequence_output)
        logits_nsw_detection = self.nsw_detector(sequence_output)

        if not return_dict:
            return (logits_norm, logits_n_masks_pred, logits_nsw_detection) + outputs[1:]

        return self._ViSoNormOutput(
            logits_norm=logits_norm,
            logits_n_masks_pred=logits_n_masks_pred,
            logits_nsw_detection=logits_nsw_detection,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def normalize_text(self, tokenizer, text, device='cpu'):
        """
        Normalize text using the ViSoNorm ViSoBERT model with NSW detection
        and greedy mask insertion.

        Args:
            tokenizer: HuggingFace tokenizer
            text: Input text to normalize
            device: Device to run inference on

        Returns:
            Tuple of (normalized_text, source_tokens, prediction_tokens)
        """
        # Move model to device
        self.to(device)

        # Step 1: Preprocess text exactly like training data.
        # Tokenize the input text into tokens (not IDs yet).
        input_tokens = tokenizer.tokenize(text)

        # Add special tokens like in training.
        input_tokens = ['<s>'] + input_tokens + ['</s>']

        # Convert tokens to IDs
        input_ids = tokenizer.convert_tokens_to_ids(input_tokens)
        input_tokens_tensor = torch.LongTensor([input_ids]).to(device)

        # Step 2: Apply the same truncation and masking logic as training.
        input_tokens_tensor, _, token_type_ids, input_mask = self._truncate_and_build_masks(input_tokens_tensor)

        # Step 3: Get all three prediction heads from the model.
        self.eval()
        with torch.no_grad():
            if hasattr(self, 'roberta'):
                # BUG FIX: the original call passed these positionally, which
                # routed the all-zero ``token_type_ids`` into ``attention_mask``
                # (masking every token).  Keyword arguments route them correctly.
                outputs = self(
                    input_ids=input_tokens_tensor,
                    attention_mask=input_mask,
                    token_type_ids=token_type_ids,
                )
            else:
                outputs = self(input_ids=input_tokens_tensor, attention_mask=input_mask)

        # Step 4: Use NSW detector to identify tokens that need normalization.
        tokens = tokenizer.convert_ids_to_tokens(input_tokens_tensor[0])

        if hasattr(outputs, 'logits_nsw_detection') and outputs.logits_nsw_detection is not None:
            # Handle different output shapes.
            if outputs.logits_nsw_detection.dim() == 3:  # (batch, seq_len, 2) - binary classification
                nsw_predictions = torch.argmax(outputs.logits_nsw_detection[0], dim=-1) == 1
            else:  # (batch, seq_len) - single output
                nsw_predictions = torch.sigmoid(outputs.logits_nsw_detection[0]) > 0.5

            tokens_need_norm = []
            for i, token in enumerate(tokens):
                # Special tokens are never normalized.
                if token in ['<s>', '</s>', '<pad>', '<unk>', '<mask>']:
                    tokens_need_norm.append(False)
                else:
                    if i < len(nsw_predictions):
                        tokens_need_norm.append(nsw_predictions[i].item())
                    else:
                        tokens_need_norm.append(False)
        else:
            # Fallback: assume all non-special tokens need checking.
            tokens_need_norm = [token not in ['<s>', '</s>', '<pad>', '<unk>', '<mask>'] for token in tokens]

        # Step 5: Greedy 0/1-mask selection per NSW position — test whether
        # inserting one <mask> after the position improves an approximate
        # sequence likelihood.

        def _score_sequence(input_ids_tensor: torch.Tensor) -> float:
            """Approximate sequence likelihood: mean of per-position max log-prob."""
            with torch.no_grad():
                scored = self(input_ids=input_ids_tensor, attention_mask=torch.ones_like(input_ids_tensor))
                logits = scored.logits_norm if hasattr(scored, 'logits_norm') else scored.logits
                log_probs = torch.log_softmax(logits[0], dim=-1)
                position_scores, _ = torch.max(log_probs, dim=-1)
                return float(position_scores.mean().item())

        mask_token_id = tokenizer.convert_tokens_to_ids('<mask>')
        working_ids = input_tokens_tensor[0].detach().clone().cpu().tolist()
        nsw_indices = [i for i, need in enumerate(tokens_need_norm) if need]

        # ``offset`` tracks how many masks were already inserted before ``i``.
        offset = 0
        for i in nsw_indices:
            pos = i + offset
            # Candidate A: no mask
            cand_a = working_ids
            score_a = _score_sequence(torch.tensor([cand_a], device=device))
            # Candidate B: add one mask after pos
            cand_b = working_ids[:pos+1] + [mask_token_id] + working_ids[pos+1:]
            score_b = _score_sequence(torch.tensor([cand_b], device=device))
            if score_b > score_a:
                working_ids = cand_b
                offset += 1

        # Step 6: Final prediction on the chosen masked sequence (may be unchanged).
        masked_input_ids = torch.tensor([working_ids], device=device)
        with torch.no_grad():
            final_outputs = self(input_ids=masked_input_ids, attention_mask=torch.ones_like(masked_input_ids))
            logits_final = final_outputs.logits_norm if hasattr(final_outputs, 'logits_norm') else final_outputs.logits
            pred_ids = torch.argmax(logits_final, dim=-1)[0].cpu().tolist()

        # Step 7: Build final token ids from predictions; keep originals at specials.
        final_tokens = []
        for idx, src_id in enumerate(working_ids):
            tok = tokenizer.convert_ids_to_tokens([src_id])[0]
            if tok in ['<s>', '</s>', '<pad>', '<unk>']:
                final_tokens.append(src_id)
            else:
                final_tokens.append(pred_ids[idx] if idx < len(pred_ids) else src_id)

        # Step 8: Convert to final text.
        def remove_special_tokens(token_list):
            special_tokens = ['<s>', '</s>', '<pad>', '<unk>', '<mask>', '<space>']
            return [token for token in token_list if token not in special_tokens]

        def _safe_ids_to_text(token_ids):
            """Best-effort id -> text; returns '' rather than raising."""
            if not token_ids:
                return ""
            try:
                tokens = tokenizer.convert_ids_to_tokens(token_ids)
                cleaned = remove_special_tokens(tokens)
                if not cleaned:
                    return ""
                return tokenizer.convert_tokens_to_string(cleaned)
            except Exception:
                return ""

        # Build final normalized text.
        final_tokens = [tid for tid in final_tokens if tid != -1]
        pred_str = _safe_ids_to_text(final_tokens)
        # Collapse repeated whitespace.
        if pred_str:
            pred_str = ' '.join(pred_str.split())

        # Also return token lists for optional inspection.
        decoded_source = tokenizer.convert_ids_to_tokens(working_ids)
        decoded_pred = tokenizer.convert_ids_to_tokens(final_tokens)

        return pred_str, decoded_source, decoded_pred

    def detect_nsw(self, tokenizer, text, device='cpu'):
        """
        Detect Non-Standard Words (NSW) in text and return detailed information.
        This method aligns with normalize_text to ensure consistent NSW detection.

        Args:
            tokenizer: HuggingFace tokenizer
            text: Input text to analyze
            device: Device to run inference on

        Returns:
            List of dictionaries containing NSW information:
            [{'index': int, 'start_index': int, 'end_index': int, 'nsw': str,
              'prediction': str, 'confidence_score': float}, ...]
        """
        # Move model to device
        self.to(device)

        # Step 1: Preprocess text exactly like normalize_text.
        input_tokens = tokenizer.tokenize(text)
        input_tokens = ['<s>'] + input_tokens + ['</s>']
        input_ids = tokenizer.convert_tokens_to_ids(input_tokens)
        input_tokens_tensor = torch.LongTensor([input_ids]).to(device)

        # Step 2: Apply the same truncation and masking logic as normalize_text.
        input_tokens_tensor, _, token_type_ids, input_mask = self._truncate_and_build_masks(input_tokens_tensor)

        # Step 3: Get all three prediction heads (same as normalize_text).
        self.eval()
        with torch.no_grad():
            if hasattr(self, 'roberta'):
                # BUG FIX: keyword arguments — the original positional call
                # sent ``token_type_ids`` into the ``attention_mask`` slot.
                outputs = self(
                    input_ids=input_tokens_tensor,
                    attention_mask=input_mask,
                    token_type_ids=token_type_ids,
                )
            else:
                outputs = self(input_ids=input_tokens_tensor, attention_mask=input_mask)

        # Step 4: NSW detection (same logic as normalize_text), plus per-token
        # confidence of the positive ("needs normalization") class.
        tokens = tokenizer.convert_ids_to_tokens(input_tokens_tensor[0])

        if hasattr(outputs, 'logits_nsw_detection') and outputs.logits_nsw_detection is not None:
            if outputs.logits_nsw_detection.dim() == 3:  # (batch, seq_len, 2) - binary classification
                nsw_predictions = torch.argmax(outputs.logits_nsw_detection[0], dim=-1) == 1
                nsw_confidence = torch.softmax(outputs.logits_nsw_detection[0], dim=-1)[:, 1]
            else:  # (batch, seq_len) - single output
                nsw_predictions = torch.sigmoid(outputs.logits_nsw_detection[0]) > 0.5
                nsw_confidence = torch.sigmoid(outputs.logits_nsw_detection[0])

            tokens_need_norm = []
            for i, token in enumerate(tokens):
                if token in ['<s>', '</s>', '<pad>', '<unk>', '<mask>']:
                    tokens_need_norm.append(False)
                else:
                    if i < len(nsw_predictions):
                        tokens_need_norm.append(nsw_predictions[i].item())
                    else:
                        tokens_need_norm.append(False)
        else:
            # Fallback: assume all non-special tokens need checking.
            tokens_need_norm = [token not in ['<s>', '</s>', '<pad>', '<unk>', '<mask>'] for token in tokens]

        # Step 5: Apply the same greedy masking strategy as normalize_text.
        def _score_sequence(input_ids_tensor: torch.Tensor) -> float:
            """Approximate sequence likelihood: mean of per-position max log-prob."""
            with torch.no_grad():
                scored = self(input_ids=input_ids_tensor, attention_mask=torch.ones_like(input_ids_tensor))
                logits = scored.logits_norm if hasattr(scored, 'logits_norm') else scored.logits
                log_probs = torch.log_softmax(logits[0], dim=-1)
                position_scores, _ = torch.max(log_probs, dim=-1)
                return float(position_scores.mean().item())

        mask_token_id = tokenizer.convert_tokens_to_ids('<mask>')
        working_ids = input_tokens_tensor[0].detach().clone().cpu().tolist()
        nsw_indices = [i for i, need in enumerate(tokens_need_norm) if need]

        offset = 0
        for i in nsw_indices:
            pos = i + offset
            # Candidate A: no mask
            cand_a = working_ids
            score_a = _score_sequence(torch.tensor([cand_a], device=device))
            # Candidate B: add one mask after pos
            cand_b = working_ids[:pos+1] + [mask_token_id] + working_ids[pos+1:]
            score_b = _score_sequence(torch.tensor([cand_b], device=device))
            if score_b > score_a:
                working_ids = cand_b
                offset += 1

        # Step 6: Get final predictions using the same masked sequence.
        masked_input_ids = torch.tensor([working_ids], device=device)
        with torch.no_grad():
            final_outputs = self(input_ids=masked_input_ids, attention_mask=torch.ones_like(masked_input_ids))
            logits_final = final_outputs.logits_norm if hasattr(final_outputs, 'logits_norm') else final_outputs.logits
            pred_ids = torch.argmax(logits_final, dim=-1)[0].cpu().tolist()

        # Step 7: Build results by comparing source vs predicted tokens, so we
        # catch all tokens that were actually changed, not just those flagged
        # by the NSW head.
        nsw_results = []

        # Final token ids: predictions at normal positions, originals at specials
        # (same as normalize_text).
        final_tokens = []
        for idx, src_id in enumerate(working_ids):
            tok = tokenizer.convert_ids_to_tokens([src_id])[0]
            if tok in ['<s>', '</s>', '<pad>', '<unk>']:
                final_tokens.append(src_id)
            else:
                final_tokens.append(pred_ids[idx] if idx < len(pred_ids) else src_id)

        # NOTE: the original also rebuilt the full normalized text and
        # re-tokenized both texts here; those values were never used, so that
        # dead code has been removed.
        decoded_source = tokenizer.convert_ids_to_tokens(working_ids)
        decoded_pred = tokenizer.convert_ids_to_tokens(final_tokens)

        # Clean the tokens (drop special tokens, strip the sentencepiece ▁ prefix).
        def clean_token(token):
            if token in ['<s>', '</s>', '<pad>', '<unk>', '<mask>']:
                return None
            return token.strip().lstrip('▁')

        # Group consecutive predictions that form expansions.
        i = 0
        while i < len(decoded_source):
            src_token = decoded_source[i]
            clean_src = clean_token(src_token)

            if clean_src is None:
                i += 1
                continue

            # Check if this token was changed.
            pred_token = decoded_pred[i]
            clean_pred = clean_token(pred_token)

            if clean_pred is None:
                i += 1
                continue

            if clean_src != clean_pred:
                # This is an NSW token - check if it's part of an expansion.
                expansion_tokens = [clean_pred]
                j = i + 1

                # Look for consecutive mask tokens that were filled.
                while j < len(decoded_source) and j < len(decoded_pred):
                    next_src = decoded_source[j]
                    next_pred = decoded_pred[j]

                    # If the source is a mask token, it was added for expansion.
                    if next_src == '<mask>':
                        clean_next_pred = clean_token(next_pred)
                        if clean_next_pred is not None:
                            expansion_tokens.append(clean_next_pred)
                        j += 1
                    else:
                        clean_next_src = clean_token(next_src)
                        clean_next_pred = clean_token(next_pred)

                        if clean_next_src is not None and clean_next_pred is not None and clean_next_src != clean_next_pred:
                            # Adjacent changed token: be conservative and only
                            # group mask-based expansions, not unrelated edits.
                            break
                        else:
                            break

                expansion_text = ' '.join(expansion_tokens)

                # Character span of the NSW in the original text (best effort:
                # first occurrence; -1 if the cleaned token is not found).
                start_idx = text.find(clean_src)
                end_idx = start_idx + len(clean_src) if start_idx != -1 else len(clean_src)

                # Confidence = mean of NSW-detector confidence and the
                # normalizer's probability for the predicted token.
                if hasattr(outputs, 'logits_nsw_detection') and outputs.logits_nsw_detection is not None:
                    # Find the corresponding position in the original token list.
                    orig_pos = None
                    for k, tok in enumerate(tokens):
                        if tok.strip().lstrip('▁') == clean_src:
                            orig_pos = k
                            break

                    if orig_pos is not None and orig_pos < len(nsw_confidence):
                        nsw_conf = nsw_confidence[orig_pos].item()
                    else:
                        nsw_conf = 0.5  # Default if position not found

                    # Normalization confidence from the final masked logits.
                    norm_logits = logits_final[0]
                    norm_confidence = torch.softmax(norm_logits, dim=-1)
                    norm_conf = norm_confidence[i][final_tokens[i]].item()
                    combined_confidence = (nsw_conf + norm_conf) / 2
                else:
                    combined_confidence = 0.5  # Default confidence if no NSW detector

                nsw_results.append({
                    'index': i,
                    'start_index': start_idx,
                    'end_index': end_idx,
                    'nsw': clean_src,
                    'prediction': expansion_text,
                    'confidence_score': round(combined_confidence, 4)
                })

                # Move to the next unprocessed token.
                i = j
            else:
                i += 1

        return nsw_results

    def _truncate_and_build_masks(self, input_tokens_tensor, output_tokens_tensor=None):
        """Apply the same truncation and masking logic as training.

        Returns ``(input_tokens_tensor, output_tokens_tensor, token_type_ids,
        input_mask)``; ``token_type_ids`` is ``None`` for non-roberta models.
        """
        if hasattr(self, 'roberta'):
            # Effective max length: the smaller of config and the actual
            # position-embedding table, minus 2 for the special tokens.
            cfg_max = int(getattr(self.roberta.config, 'max_position_embeddings', input_tokens_tensor.size(1)))
            tbl_max = int(getattr(self.roberta.embeddings.position_embeddings, 'num_embeddings', cfg_max))
            max_pos = min(cfg_max, tbl_max)
            eff_max = max(1, max_pos - 2)
            if input_tokens_tensor.size(1) > eff_max:
                input_tokens_tensor = input_tokens_tensor[:, :eff_max]
            if output_tokens_tensor is not None and output_tokens_tensor.dim() == 2 and output_tokens_tensor.size(1) > eff_max:
                output_tokens_tensor = output_tokens_tensor[:, :eff_max]
            # Resolve the pad id from config, then the embedding padding_idx,
            # then fall back to the XLM-R default of 1.
            pad_id_model = getattr(self.roberta.config, 'pad_token_id', None)
            if pad_id_model is None:
                pad_id_model = getattr(self.roberta.embeddings.word_embeddings, 'padding_idx', None)
            if pad_id_model is None:
                pad_id_model = 1  # Default pad token ID
            input_mask = (input_tokens_tensor != pad_id_model).long()
            token_type_ids = torch.zeros_like(input_tokens_tensor)
            return input_tokens_tensor, output_tokens_tensor, token_type_ids, input_mask
        # bart branch: no token types; attend to everything.
        input_mask = torch.ones_like(input_tokens_tensor)
        token_type_ids = None
        return input_tokens_tensor, output_tokens_tensor, token_type_ids, input_mask
555
+
556
+
557
# Explicit public API: only the model class is meant to be imported from here.
__all__ = ["ViSoNormViSoBERTForMaskedLM"]