Commit ·
2c5b728
1
Parent(s): 8db6f43
Upload ViSoNorm trained model
Browse files- README.md +117 -0
- added_tokens.json +3 -0
- config.json +25 -0
- pytorch_model.bin +3 -0
- sentencepiece.bpe.model +3 -0
- special_tokens_map.json +15 -0
- state_dict_report.json +56 -0
- tokenizer_config.json +19 -0
- training_args.json +10 -0
- visonorm_visobert_model.py +557 -0
README.md
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ViSoNorm: Vietnamese Text Normalization Model
|
| 2 |
+
|
| 3 |
+
ViSoNorm is a state-of-the-art Vietnamese text normalization model that converts informal, non-standard Vietnamese text into standard Vietnamese. The model uses a multi-task learning approach with NSW (Non-Standard Word) detection, mask prediction, and lexical normalization heads.
|
| 4 |
+
|
| 5 |
+
## Model Architecture
|
| 6 |
+
|
| 7 |
+
- **Base Model**: ViSoBERT (Vietnamese Social Media BERT)
|
| 8 |
+
- **Multi-task Heads**:
|
| 9 |
+
- NSW Detection: Identifies tokens that need normalization
|
| 10 |
+
- Mask Prediction: Determines how many masks to add for multi-token expansions
|
| 11 |
+
- Lexical Normalization: Predicts normalized tokens
|
| 12 |
+
|
| 13 |
+
## Features
|
| 14 |
+
|
| 15 |
+
- **Self-contained inference**: Built-in `normalize_text` method
|
| 16 |
+
- **NSW detection**: Built-in `detect_nsw` method for detailed analysis
|
| 17 |
+
- **HuggingFace compatible**: Works seamlessly with `AutoModelForMaskedLM`
|
| 18 |
+
- **Production ready**: No hardcoded patterns, works for any Vietnamese text
|
| 19 |
+
- **Multi-token expansion**: Handles cases like "sv" → "sinh viên", "ctrai" → "con trai"
|
| 20 |
+
- **Confidence scoring**: Provides confidence scores for NSW detection and normalization
|
| 21 |
+
|
| 22 |
+
## Installation
|
| 23 |
+
|
| 24 |
+
```bash
|
| 25 |
+
pip install transformers torch
|
| 26 |
+
```
|
| 27 |
+
|
| 28 |
+
## Usage
|
| 29 |
+
|
| 30 |
+
### Basic Usage
|
| 31 |
+
|
| 32 |
+
```python
|
| 33 |
+
from transformers import AutoTokenizer, AutoModelForMaskedLM
|
| 34 |
+
|
| 35 |
+
# Load model and tokenizer
|
| 36 |
+
model_repo = "hadung1802/visobert-normalizer"
|
| 37 |
+
tokenizer = AutoTokenizer.from_pretrained(model_repo)
|
| 38 |
+
model = AutoModelForMaskedLM.from_pretrained(model_repo, trust_remote_code=True)
|
| 39 |
+
|
| 40 |
+
# Normalize text
|
| 41 |
+
text = "sv dh gia dinh chua cho di lam :))"
|
| 42 |
+
normalized_text, source_tokens, predicted_tokens = model.normalize_text(
|
| 43 |
+
tokenizer, text, device='cpu'
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
print(f"Original: {text}")
|
| 47 |
+
print(f"Normalized: {normalized_text}")
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
### NSW Detection
|
| 51 |
+
|
| 52 |
+
```python
|
| 53 |
+
# Detect Non-Standard Words (NSW) in text
|
| 54 |
+
text = "nhìn thôi cung thấy đau long quá đi :))"
|
| 55 |
+
nsw_results = model.detect_nsw(tokenizer, text, device='cpu')
|
| 56 |
+
|
| 57 |
+
print(f"Text: {text}")
|
| 58 |
+
for result in nsw_results:
|
| 59 |
+
print(f"NSW: '{result['nsw']}' → '{result['prediction']}' (confidence: {result['confidence_score']})")
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
### Batch Processing
|
| 63 |
+
|
| 64 |
+
```python
|
| 65 |
+
texts = [
|
| 66 |
+
"sv dh gia dinh chua cho di lam :))",
|
| 67 |
+
"chúng nó bảo em là ctrai",
|
| 68 |
+
"t vs b chơi vs nhau đã lâu"
|
| 69 |
+
]
|
| 70 |
+
|
| 71 |
+
for text in texts:
|
| 72 |
+
normalized_text, _, _ = model.normalize_text(tokenizer, text, device='cpu')
|
| 73 |
+
print(f"{text} → {normalized_text}")
|
| 74 |
+
```
|
| 75 |
+
|
| 76 |
+
### Expected Output
|
| 77 |
+
|
| 78 |
+
#### Text Normalization
|
| 79 |
+
```
|
| 80 |
+
sv dh gia dinh chua cho di lam :)) → sinh viên đại học gia đình chưa cho đi làm :))
|
| 81 |
+
chúng nó bảo em là ctrai → chúng nó bảo em là con trai
|
| 82 |
+
t vs b chơi vs nhau đã lâu → tôi với bạn chơi với nhau đã lâu
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
#### NSW Detection
|
| 86 |
+
```python
|
| 87 |
+
# Input: "nhìn thôi cung thấy đau long quá đi :))"
|
| 88 |
+
[
|
| 89 |
+
{
|
| 90 |
+
"index": 3,
|
| 91 |
+
"start_index": 10,
|
| 92 |
+
"end_index": 14,
|
| 93 |
+
"nsw": "cung",
|
| 94 |
+
"prediction": "cũng",
|
| 95 |
+
"confidence_score": 0.9415
|
| 96 |
+
},
|
| 97 |
+
{
|
| 98 |
+
"index": 6,
|
| 99 |
+
"start_index": 24,
|
| 100 |
+
"end_index": 28,
|
| 101 |
+
"nsw": "long",
|
| 102 |
+
"prediction": "lòng",
|
| 103 |
+
"confidence_score": 0.7056
|
| 104 |
+
}
|
| 105 |
+
]
|
| 106 |
+
```
|
| 107 |
+
|
| 108 |
+
### NSW Detection Output Format
|
| 109 |
+
|
| 110 |
+
The `detect_nsw` method returns a list of dictionaries with the following structure:
|
| 111 |
+
|
| 112 |
+
- **`index`**: Position of the token in the sequence
|
| 113 |
+
- **`start_index`**: Start character position in the original text
|
| 114 |
+
- **`end_index`**: End character position in the original text
|
| 115 |
+
- **`nsw`**: The original non-standard word (detokenized)
|
| 116 |
+
- **`prediction`**: The predicted normalized word (detokenized)
|
| 117 |
+
- **`confidence_score`**: Combined confidence score (0.0 to 1.0)
|
added_tokens.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"<space>": 15002
|
| 3 |
+
}
|
config.json
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"ViSoNormViSoBERTForMaskedLM"
|
| 4 |
+
],
|
| 5 |
+
"model_type": "xlm-roberta",
|
| 6 |
+
"vocab_size": 15003,
|
| 7 |
+
"pad_token_id": 1,
|
| 8 |
+
"bos_token_id": 0,
|
| 9 |
+
"eos_token_id": 2,
|
| 10 |
+
"mask_token_id": 3,
|
| 11 |
+
"mask_n_predictor": true,
|
| 12 |
+
"nsw_detector": true,
|
| 13 |
+
"auto_map": {
|
| 14 |
+
"AutoModel": "visonorm_visobert_model.ViSoNormViSoBERTForMaskedLM",
|
| 15 |
+
"AutoModelForMaskedLM": "visonorm_visobert_model.ViSoNormViSoBERTForMaskedLM"
|
| 16 |
+
},
|
| 17 |
+
"hidden_size": 768,
|
| 18 |
+
"num_attention_heads": 12,
|
| 19 |
+
"num_hidden_layers": 12,
|
| 20 |
+
"intermediate_size": 3072,
|
| 21 |
+
"max_position_embeddings": 514,
|
| 22 |
+
"type_vocab_size": 2,
|
| 23 |
+
"initializer_range": 0.02,
|
| 24 |
+
"layer_norm_eps": 1e-12
|
| 25 |
+
}
|
pytorch_model.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4c56f8433e6e26b0d6411414956a01eee2acc0c0471173b6e74e74a352b84993
|
| 3 |
+
size 393240883
|
sentencepiece.bpe.model
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:02aaf05cda4db99e86b7c76eba6258867ce4d043da0fed19c87a7d46c8b53a65
|
| 3 |
+
size 470732
|
special_tokens_map.json
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token": "<s>",
|
| 3 |
+
"cls_token": "<s>",
|
| 4 |
+
"eos_token": "</s>",
|
| 5 |
+
"mask_token": {
|
| 6 |
+
"content": "<mask>",
|
| 7 |
+
"lstrip": true,
|
| 8 |
+
"normalized": false,
|
| 9 |
+
"rstrip": false,
|
| 10 |
+
"single_word": false
|
| 11 |
+
},
|
| 12 |
+
"pad_token": "<pad>",
|
| 13 |
+
"sep_token": "</s>",
|
| 14 |
+
"unk_token": "<unk>"
|
| 15 |
+
}
|
state_dict_report.json
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"base_model": "visobert",
|
| 3 |
+
"total_params": 213,
|
| 4 |
+
"expected_heads_present": {
|
| 5 |
+
"cls_decoder.weight": false,
|
| 6 |
+
"cls_decoder.bias": false,
|
| 7 |
+
"cls_dense.weight": false,
|
| 8 |
+
"cls_dense.bias": false,
|
| 9 |
+
"cls_layer_norm.weight": false,
|
| 10 |
+
"cls_layer_norm.bias": false,
|
| 11 |
+
"mask_n_predictor.mask_predictor_dense.weight": true,
|
| 12 |
+
"mask_n_predictor.mask_predictor_dense.bias": true,
|
| 13 |
+
"mask_n_predictor.mask_predictor_proj.weight": true,
|
| 14 |
+
"mask_n_predictor.mask_predictor_proj.bias": true,
|
| 15 |
+
"nsw_detector.dense.weight": true,
|
| 16 |
+
"nsw_detector.dense.bias": true,
|
| 17 |
+
"nsw_detector.predictor.weight": true,
|
| 18 |
+
"nsw_detector.predictor.bias": true
|
| 19 |
+
},
|
| 20 |
+
"alt_common_heads_present": {
|
| 21 |
+
"lm_head.weight": false,
|
| 22 |
+
"lm_head.bias": false,
|
| 23 |
+
"cls.decoder.weight": true,
|
| 24 |
+
"cls.decoder.bias": true,
|
| 25 |
+
"cls.dense.weight": true,
|
| 26 |
+
"cls.dense.bias": true,
|
| 27 |
+
"cls.layer_norm.weight": true,
|
| 28 |
+
"cls.layer_norm.bias": true
|
| 29 |
+
},
|
| 30 |
+
"aux_heads_present": {
|
| 31 |
+
"nsw_detector.": true,
|
| 32 |
+
"mask_n_predictor.": true
|
| 33 |
+
},
|
| 34 |
+
"example_keys": [
|
| 35 |
+
"roberta.embeddings.word_embeddings.weight",
|
| 36 |
+
"roberta.embeddings.position_embeddings.weight",
|
| 37 |
+
"roberta.embeddings.token_type_embeddings.weight",
|
| 38 |
+
"roberta.embeddings.LayerNorm.weight",
|
| 39 |
+
"roberta.embeddings.LayerNorm.bias",
|
| 40 |
+
"roberta.encoder.layer.0.attention.self.query.weight",
|
| 41 |
+
"roberta.encoder.layer.0.attention.self.query.bias",
|
| 42 |
+
"roberta.encoder.layer.0.attention.self.key.weight",
|
| 43 |
+
"roberta.encoder.layer.0.attention.self.key.bias",
|
| 44 |
+
"roberta.encoder.layer.0.attention.self.value.weight",
|
| 45 |
+
"roberta.encoder.layer.0.attention.self.value.bias",
|
| 46 |
+
"roberta.encoder.layer.0.attention.output.dense.weight",
|
| 47 |
+
"roberta.encoder.layer.0.attention.output.dense.bias",
|
| 48 |
+
"roberta.encoder.layer.0.attention.output.LayerNorm.weight",
|
| 49 |
+
"roberta.encoder.layer.0.attention.output.LayerNorm.bias",
|
| 50 |
+
"roberta.encoder.layer.0.intermediate.dense.weight",
|
| 51 |
+
"roberta.encoder.layer.0.intermediate.dense.bias",
|
| 52 |
+
"roberta.encoder.layer.0.output.dense.weight",
|
| 53 |
+
"roberta.encoder.layer.0.output.dense.bias",
|
| 54 |
+
"roberta.encoder.layer.0.output.LayerNorm.weight"
|
| 55 |
+
]
|
| 56 |
+
}
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"tokenizer_class": "XLMRobertaTokenizer",
|
| 3 |
+
"model_max_length": 512,
|
| 4 |
+
"padding_side": "right",
|
| 5 |
+
"truncation_side": "right",
|
| 6 |
+
"pad_token": "<pad>",
|
| 7 |
+
"bos_token": "<s>",
|
| 8 |
+
"eos_token": "</s>",
|
| 9 |
+
"unk_token": "<unk>",
|
| 10 |
+
"mask_token": "<mask>",
|
| 11 |
+
"additional_special_tokens": [
|
| 12 |
+
"<pad>",
|
| 13 |
+
"<s>",
|
| 14 |
+
"</s>",
|
| 15 |
+
"<unk>",
|
| 16 |
+
"<mask>"
|
| 17 |
+
],
|
| 18 |
+
"use_fast": false
|
| 19 |
+
}
|
training_args.json
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"base_model": "visobert",
|
| 3 |
+
"training_mode": "weakly_supervised",
|
| 4 |
+
"learning_rate": 0.001,
|
| 5 |
+
"num_epochs": 10,
|
| 6 |
+
"train_batch_size": 16,
|
| 7 |
+
"eval_batch_size": 128,
|
| 8 |
+
"remove_accents": false,
|
| 9 |
+
"lower_case": false
|
| 10 |
+
}
|
visonorm_visobert_model.py
ADDED
|
@@ -0,0 +1,557 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Custom ViSoNorm model class for ViSoBERT-based models.
|
| 4 |
+
This preserves the custom heads needed for text normalization and
|
| 5 |
+
is loadable via auto_map without custom model_type.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import math
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from transformers import XLMRobertaModel, XLMRobertaConfig, XLMRobertaPreTrainedModel
|
| 12 |
+
from transformers.modeling_outputs import MaskedLMOutput
|
| 13 |
+
# Define constants locally to avoid external dependencies
|
| 14 |
+
NUM_LABELS_N_MASKS = 5
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def gelu(x):
    """Gaussian Error Linear Unit activation, exact erf formulation.

    Computes ``x * Phi(x)`` where ``Phi`` is the standard normal CDF.
    """
    return 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class XLMRobertaLMHead(nn.Module):
    """Masked-LM head whose output projection is tied to the input embeddings.

    Both the hidden size and the vocabulary size are derived from the supplied
    embedding weight tensor rather than from ``config``, so the head always
    matches the pretrained model it is attached to.
    """

    def __init__(self, config, xlmroberta_model_embedding_weights):
        super().__init__()
        hidden_dim = xlmroberta_model_embedding_weights.size(1)
        vocab_dim = xlmroberta_model_embedding_weights.size(0)

        self.dense = nn.Linear(hidden_dim, hidden_dim)
        self.layer_norm = nn.LayerNorm(hidden_dim, eps=1e-12)

        # Weight tying: the decoder shares the embedding matrix; only the
        # output bias is a fresh (zero-initialized) parameter.
        self.decoder = nn.Linear(hidden_dim, vocab_dim, bias=False)
        self.decoder.weight = xlmroberta_model_embedding_weights
        self.decoder.bias = nn.Parameter(torch.zeros(vocab_dim))

    def forward(self, features):
        hidden = self.layer_norm(gelu(self.dense(features)))
        return self.decoder(hidden)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class XLMRobertaMaskNPredictionHead(nn.Module):
    """Per-token classifier predicting the number of masks to insert.

    A small two-layer MLP (fixed hidden width of 50) that emits
    ``NUM_LABELS_N_MASKS`` class scores for every position of the sequence.
    """

    def __init__(self, config, actual_hidden_size):
        super().__init__()
        self.mask_predictor_dense = nn.Linear(actual_hidden_size, 50)
        self.mask_predictor_proj = nn.Linear(50, NUM_LABELS_N_MASKS)
        self.activation = gelu

    def forward(self, sequence_output):
        hidden = self.activation(self.mask_predictor_dense(sequence_output))
        return self.mask_predictor_proj(hidden)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class XLMRobertaBinaryPredictor(nn.Module):
    """Two-way per-token classifier (used here for NSW detection).

    The attribute names ``dense`` and ``predictor`` deliberately match the
    trained checkpoint's parameter names so state dicts load without remapping.
    """

    def __init__(self, hidden_size, dense_dim=100):
        super().__init__()
        self.dense = nn.Linear(hidden_size, dense_dim)
        self.predictor = nn.Linear(dense_dim, 2)
        self.activation = gelu

    def forward(self, sequence_output):
        hidden = self.activation(self.dense(sequence_output))
        return self.predictor(hidden)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class ViSoNormViSoBERTForMaskedLM(XLMRobertaPreTrainedModel):
|
| 70 |
+
config_class = XLMRobertaConfig
|
| 71 |
+
|
| 72 |
+
def __init__(self, config: XLMRobertaConfig):
|
| 73 |
+
super().__init__(config)
|
| 74 |
+
self.roberta = XLMRobertaModel(config)
|
| 75 |
+
|
| 76 |
+
# Get actual hidden size from the pretrained model
|
| 77 |
+
actual_hidden_size = self.roberta.embeddings.word_embeddings.weight.size(1)
|
| 78 |
+
|
| 79 |
+
# ViSoNorm normalization head - use exact same structure as training
|
| 80 |
+
self.cls = XLMRobertaLMHead(config, self.roberta.embeddings.word_embeddings.weight)
|
| 81 |
+
|
| 82 |
+
# Additional heads for ViSoNorm functionality
|
| 83 |
+
self.mask_n_predictor = XLMRobertaMaskNPredictionHead(config, actual_hidden_size)
|
| 84 |
+
self.nsw_detector = XLMRobertaBinaryPredictor(actual_hidden_size, dense_dim=100)
|
| 85 |
+
self.num_labels_n_mask = NUM_LABELS_N_MASKS
|
| 86 |
+
|
| 87 |
+
# Initialize per HF conventions
|
| 88 |
+
self.post_init()
|
| 89 |
+
|
| 90 |
+
def forward(
|
| 91 |
+
self,
|
| 92 |
+
input_ids=None,
|
| 93 |
+
attention_mask=None,
|
| 94 |
+
token_type_ids=None,
|
| 95 |
+
position_ids=None,
|
| 96 |
+
head_mask=None,
|
| 97 |
+
inputs_embeds=None,
|
| 98 |
+
encoder_hidden_states=None,
|
| 99 |
+
encoder_attention_mask=None,
|
| 100 |
+
labels=None,
|
| 101 |
+
output_attentions=None,
|
| 102 |
+
output_hidden_states=None,
|
| 103 |
+
return_dict=None,
|
| 104 |
+
):
|
| 105 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 106 |
+
|
| 107 |
+
outputs = self.roberta(
|
| 108 |
+
input_ids=input_ids,
|
| 109 |
+
attention_mask=attention_mask,
|
| 110 |
+
token_type_ids=token_type_ids,
|
| 111 |
+
position_ids=position_ids,
|
| 112 |
+
head_mask=head_mask,
|
| 113 |
+
inputs_embeds=inputs_embeds,
|
| 114 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 115 |
+
encoder_attention_mask=encoder_attention_mask,
|
| 116 |
+
output_attentions=output_attentions,
|
| 117 |
+
output_hidden_states=output_hidden_states,
|
| 118 |
+
return_dict=return_dict,
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
sequence_output = outputs[0]
|
| 122 |
+
|
| 123 |
+
# Calculate all three prediction heads
|
| 124 |
+
logits_norm = self.cls(sequence_output)
|
| 125 |
+
logits_n_masks_pred = self.mask_n_predictor(sequence_output)
|
| 126 |
+
logits_nsw_detection = self.nsw_detector(sequence_output)
|
| 127 |
+
|
| 128 |
+
if not return_dict:
|
| 129 |
+
return (logits_norm, logits_n_masks_pred, logits_nsw_detection) + outputs[1:]
|
| 130 |
+
|
| 131 |
+
# Return all prediction heads for ViSoNorm inference
|
| 132 |
+
# Create a custom output object that contains all three heads
|
| 133 |
+
class ViSoNormOutput:
|
| 134 |
+
def __init__(self, logits_norm, logits_n_masks_pred, logits_nsw_detection, hidden_states=None, attentions=None):
|
| 135 |
+
self.logits = logits_norm
|
| 136 |
+
self.logits_norm = logits_norm
|
| 137 |
+
self.logits_n_masks_pred = logits_n_masks_pred
|
| 138 |
+
self.logits_nsw_detection = logits_nsw_detection
|
| 139 |
+
self.hidden_states = hidden_states
|
| 140 |
+
self.attentions = attentions
|
| 141 |
+
|
| 142 |
+
return ViSoNormOutput(
|
| 143 |
+
logits_norm=logits_norm,
|
| 144 |
+
logits_n_masks_pred=logits_n_masks_pred,
|
| 145 |
+
logits_nsw_detection=logits_nsw_detection,
|
| 146 |
+
hidden_states=outputs.hidden_states,
|
| 147 |
+
attentions=outputs.attentions,
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
def normalize_text(self, tokenizer, text, device='cpu'):
|
| 151 |
+
"""
|
| 152 |
+
Normalize text using the ViSoNorm ViSoBERT model with proper NSW detection and masking.
|
| 153 |
+
|
| 154 |
+
Args:
|
| 155 |
+
tokenizer: HuggingFace tokenizer
|
| 156 |
+
text: Input text to normalize
|
| 157 |
+
device: Device to run inference on
|
| 158 |
+
|
| 159 |
+
Returns:
|
| 160 |
+
Tuple of (normalized_text, source_tokens, prediction_tokens)
|
| 161 |
+
"""
|
| 162 |
+
# Move model to device
|
| 163 |
+
self.to(device)
|
| 164 |
+
|
| 165 |
+
# Step 1: Preprocess text exactly like training data
|
| 166 |
+
# Tokenize the input text into tokens (not IDs yet)
|
| 167 |
+
input_tokens = tokenizer.tokenize(text)
|
| 168 |
+
|
| 169 |
+
# Add special tokens like in training
|
| 170 |
+
input_tokens = ['<s>'] + input_tokens + ['</s>']
|
| 171 |
+
|
| 172 |
+
# Convert tokens to IDs
|
| 173 |
+
input_ids = tokenizer.convert_tokens_to_ids(input_tokens)
|
| 174 |
+
input_tokens_tensor = torch.LongTensor([input_ids]).to(device)
|
| 175 |
+
|
| 176 |
+
# Step 2: Apply the same truncation and masking logic as training
|
| 177 |
+
input_tokens_tensor, _, token_type_ids, input_mask = self._truncate_and_build_masks(input_tokens_tensor)
|
| 178 |
+
|
| 179 |
+
# Step 3: Get all three prediction heads from ViSoNorm model
|
| 180 |
+
self.eval()
|
| 181 |
+
with torch.no_grad():
|
| 182 |
+
if hasattr(self, 'roberta'):
|
| 183 |
+
outputs = self(input_tokens_tensor, token_type_ids, input_mask)
|
| 184 |
+
else:
|
| 185 |
+
outputs = self(input_tokens_tensor, input_mask)
|
| 186 |
+
|
| 187 |
+
# Step 4: Use NSW detector to identify tokens that need normalization
|
| 188 |
+
tokens = tokenizer.convert_ids_to_tokens(input_tokens_tensor[0])
|
| 189 |
+
|
| 190 |
+
if hasattr(outputs, 'logits_nsw_detection') and outputs.logits_nsw_detection is not None:
|
| 191 |
+
# Handle different output shapes
|
| 192 |
+
if outputs.logits_nsw_detection.dim() == 3: # (batch, seq_len, 2) - binary classification
|
| 193 |
+
nsw_predictions = torch.argmax(outputs.logits_nsw_detection[0], dim=-1) == 1
|
| 194 |
+
else: # (batch, seq_len) - single output
|
| 195 |
+
nsw_predictions = torch.sigmoid(outputs.logits_nsw_detection[0]) > 0.5
|
| 196 |
+
|
| 197 |
+
tokens_need_norm = []
|
| 198 |
+
for i, token in enumerate(tokens):
|
| 199 |
+
# Skip special tokens
|
| 200 |
+
if token in ['<s>', '</s>', '<pad>', '<unk>', '<mask>']:
|
| 201 |
+
tokens_need_norm.append(False)
|
| 202 |
+
else:
|
| 203 |
+
if i < len(nsw_predictions):
|
| 204 |
+
tokens_need_norm.append(nsw_predictions[i].item())
|
| 205 |
+
else:
|
| 206 |
+
tokens_need_norm.append(False)
|
| 207 |
+
else:
|
| 208 |
+
# Fallback: assume all non-special tokens need checking
|
| 209 |
+
tokens_need_norm = [token not in ['<s>', '</s>', '<pad>', '<unk>', '<mask>'] for token in tokens]
|
| 210 |
+
|
| 211 |
+
# Update NSW tokens list (purely model-driven or generic non-special fallback)
|
| 212 |
+
nsw_tokens = [tokens[i] for i, need in enumerate(tokens_need_norm) if need]
|
| 213 |
+
|
| 214 |
+
# Step 5: Greedy 0/1-mask selection when heads are unusable
|
| 215 |
+
# Try, per NSW position, whether adding one mask improves sequence likelihood
|
| 216 |
+
|
| 217 |
+
def _score_sequence(input_ids_tensor: torch.Tensor) -> float:
|
| 218 |
+
with torch.no_grad():
|
| 219 |
+
scored = self(input_ids=input_ids_tensor, attention_mask=torch.ones_like(input_ids_tensor))
|
| 220 |
+
logits = scored.logits_norm if hasattr(scored, 'logits_norm') else scored.logits
|
| 221 |
+
log_probs = torch.log_softmax(logits[0], dim=-1)
|
| 222 |
+
# Score by taking the max log-prob at each position (approximate sequence likelihood)
|
| 223 |
+
position_scores, _ = torch.max(log_probs, dim=-1)
|
| 224 |
+
return float(position_scores.mean().item())
|
| 225 |
+
|
| 226 |
+
mask_token_id = tokenizer.convert_tokens_to_ids('<mask>')
|
| 227 |
+
working_ids = input_tokens_tensor[0].detach().clone().cpu().tolist()
|
| 228 |
+
nsw_indices = [i for i, need in enumerate(tokens_need_norm) if need]
|
| 229 |
+
|
| 230 |
+
offset = 0
|
| 231 |
+
for i in nsw_indices:
|
| 232 |
+
pos = i + offset
|
| 233 |
+
# Candidate A: no mask
|
| 234 |
+
cand_a = working_ids
|
| 235 |
+
score_a = _score_sequence(torch.tensor([cand_a], device=device))
|
| 236 |
+
# Candidate B: add one mask after pos
|
| 237 |
+
cand_b = working_ids[:pos+1] + [mask_token_id] + working_ids[pos+1:]
|
| 238 |
+
score_b = _score_sequence(torch.tensor([cand_b], device=device))
|
| 239 |
+
if score_b > score_a:
|
| 240 |
+
working_ids = cand_b
|
| 241 |
+
offset += 1
|
| 242 |
+
|
| 243 |
+
# Final prediction on the chosen masked sequence (may be unchanged)
|
| 244 |
+
masked_input_ids = torch.tensor([working_ids], device=device)
|
| 245 |
+
with torch.no_grad():
|
| 246 |
+
final_outputs = self(input_ids=masked_input_ids, attention_mask=torch.ones_like(masked_input_ids))
|
| 247 |
+
logits_final = final_outputs.logits_norm if hasattr(final_outputs, 'logits_norm') else final_outputs.logits
|
| 248 |
+
pred_ids = torch.argmax(logits_final, dim=-1)[0].cpu().tolist()
|
| 249 |
+
|
| 250 |
+
# Build final token ids by taking predictions at positions; keep originals at specials
|
| 251 |
+
final_tokens = []
|
| 252 |
+
for idx, src_id in enumerate(working_ids):
|
| 253 |
+
tok = tokenizer.convert_ids_to_tokens([src_id])[0]
|
| 254 |
+
if tok in ['<s>', '</s>', '<pad>', '<unk>']:
|
| 255 |
+
final_tokens.append(src_id)
|
| 256 |
+
else:
|
| 257 |
+
final_tokens.append(pred_ids[idx] if idx < len(pred_ids) else src_id)
|
| 258 |
+
|
| 259 |
+
# Step 9: Convert to final text
|
| 260 |
+
def remove_special_tokens(token_list):
|
| 261 |
+
special_tokens = ['<s>', '</s>', '<pad>', '<unk>', '<mask>', '<space>']
|
| 262 |
+
return [token for token in token_list if token not in special_tokens]
|
| 263 |
+
|
| 264 |
+
def _safe_ids_to_text(token_ids):
|
| 265 |
+
if not token_ids:
|
| 266 |
+
return ""
|
| 267 |
+
try:
|
| 268 |
+
tokens = tokenizer.convert_ids_to_tokens(token_ids)
|
| 269 |
+
cleaned = remove_special_tokens(tokens)
|
| 270 |
+
if not cleaned:
|
| 271 |
+
return ""
|
| 272 |
+
return tokenizer.convert_tokens_to_string(cleaned)
|
| 273 |
+
except Exception:
|
| 274 |
+
return ""
|
| 275 |
+
|
| 276 |
+
# Build final normalized text
|
| 277 |
+
final_tokens = [tid for tid in final_tokens if tid != -1]
|
| 278 |
+
pred_str = _safe_ids_to_text(final_tokens)
|
| 279 |
+
# Collapse repeated whitespace
|
| 280 |
+
if pred_str:
|
| 281 |
+
pred_str = ' '.join(pred_str.split())
|
| 282 |
+
|
| 283 |
+
# Also return token lists for optional inspection
|
| 284 |
+
decoded_source = tokenizer.convert_ids_to_tokens(working_ids)
|
| 285 |
+
decoded_pred = tokenizer.convert_ids_to_tokens(final_tokens)
|
| 286 |
+
|
| 287 |
+
return pred_str, decoded_source, decoded_pred
|
| 288 |
+
|
| 289 |
+
def detect_nsw(self, tokenizer, text, device='cpu'):
|
| 290 |
+
"""
|
| 291 |
+
Detect Non-Standard Words (NSW) in text and return detailed information.
|
| 292 |
+
This method aligns with normalize_text to ensure consistent NSW detection.
|
| 293 |
+
|
| 294 |
+
Args:
|
| 295 |
+
tokenizer: HuggingFace tokenizer
|
| 296 |
+
text: Input text to analyze
|
| 297 |
+
device: Device to run inference on
|
| 298 |
+
|
| 299 |
+
Returns:
|
| 300 |
+
List of dictionaries containing NSW information:
|
| 301 |
+
[{'index': int, 'start_index': int, 'end_index': int, 'nsw': str,
|
| 302 |
+
'prediction': str, 'confidence_score': float}, ...]
|
| 303 |
+
"""
|
| 304 |
+
# Move model to device
|
| 305 |
+
self.to(device)
|
| 306 |
+
|
| 307 |
+
# Step 1: Preprocess text exactly like normalize_text
|
| 308 |
+
input_tokens = tokenizer.tokenize(text)
|
| 309 |
+
input_tokens = ['<s>'] + input_tokens + ['</s>']
|
| 310 |
+
input_ids = tokenizer.convert_tokens_to_ids(input_tokens)
|
| 311 |
+
input_tokens_tensor = torch.LongTensor([input_ids]).to(device)
|
| 312 |
+
|
| 313 |
+
# Step 2: Apply the same truncation and masking logic as normalize_text
|
| 314 |
+
input_tokens_tensor, _, token_type_ids, input_mask = self._truncate_and_build_masks(input_tokens_tensor)
|
| 315 |
+
|
| 316 |
+
# Step 3: Get all three prediction heads from ViSoNorm model (same as normalize_text)
|
| 317 |
+
self.eval()
|
| 318 |
+
with torch.no_grad():
|
| 319 |
+
if hasattr(self, 'roberta'):
|
| 320 |
+
outputs = self(input_tokens_tensor, token_type_ids, input_mask)
|
| 321 |
+
else:
|
| 322 |
+
outputs = self(input_tokens_tensor, input_mask)
|
| 323 |
+
|
| 324 |
+
# Step 4: Use NSW detector to identify tokens that need normalization (same logic as normalize_text)
|
| 325 |
+
tokens = tokenizer.convert_ids_to_tokens(input_tokens_tensor[0])
|
| 326 |
+
|
| 327 |
+
if hasattr(outputs, 'logits_nsw_detection') and outputs.logits_nsw_detection is not None:
|
| 328 |
+
# Handle different output shapes (same as normalize_text)
|
| 329 |
+
if outputs.logits_nsw_detection.dim() == 3: # (batch, seq_len, 2) - binary classification
|
| 330 |
+
nsw_predictions = torch.argmax(outputs.logits_nsw_detection[0], dim=-1) == 1
|
| 331 |
+
nsw_confidence = torch.softmax(outputs.logits_nsw_detection[0], dim=-1)[:, 1]
|
| 332 |
+
else: # (batch, seq_len) - single output
|
| 333 |
+
nsw_predictions = torch.sigmoid(outputs.logits_nsw_detection[0]) > 0.5
|
| 334 |
+
nsw_confidence = torch.sigmoid(outputs.logits_nsw_detection[0])
|
| 335 |
+
|
| 336 |
+
tokens_need_norm = []
|
| 337 |
+
for i, token in enumerate(tokens):
|
| 338 |
+
# Skip special tokens (same as normalize_text)
|
| 339 |
+
if token in ['<s>', '</s>', '<pad>', '<unk>', '<mask>']:
|
| 340 |
+
tokens_need_norm.append(False)
|
| 341 |
+
else:
|
| 342 |
+
if i < len(nsw_predictions):
|
| 343 |
+
tokens_need_norm.append(nsw_predictions[i].item())
|
| 344 |
+
else:
|
| 345 |
+
tokens_need_norm.append(False)
|
| 346 |
+
else:
|
| 347 |
+
# Fallback: assume all non-special tokens need checking (same as normalize_text)
|
| 348 |
+
tokens_need_norm = [token not in ['<s>', '</s>', '<pad>', '<unk>', '<mask>'] for token in tokens]
|
| 349 |
+
|
| 350 |
+
# Step 5: Apply the same masking strategy as normalize_text
|
| 351 |
+
def _score_sequence(input_ids_tensor: torch.Tensor) -> float:
|
| 352 |
+
with torch.no_grad():
|
| 353 |
+
scored = self(input_ids=input_ids_tensor, attention_mask=torch.ones_like(input_ids_tensor))
|
| 354 |
+
logits = scored.logits_norm if hasattr(scored, 'logits_norm') else scored.logits
|
| 355 |
+
log_probs = torch.log_softmax(logits[0], dim=-1)
|
| 356 |
+
position_scores, _ = torch.max(log_probs, dim=-1)
|
| 357 |
+
return float(position_scores.mean().item())
|
| 358 |
+
|
| 359 |
+
mask_token_id = tokenizer.convert_tokens_to_ids('<mask>')
|
| 360 |
+
working_ids = input_tokens_tensor[0].detach().clone().cpu().tolist()
|
| 361 |
+
nsw_indices = [i for i, need in enumerate(tokens_need_norm) if need]
|
| 362 |
+
|
| 363 |
+
offset = 0
|
| 364 |
+
for i in nsw_indices:
|
| 365 |
+
pos = i + offset
|
| 366 |
+
# Candidate A: no mask
|
| 367 |
+
cand_a = working_ids
|
| 368 |
+
score_a = _score_sequence(torch.tensor([cand_a], device=device))
|
| 369 |
+
# Candidate B: add one mask after pos
|
| 370 |
+
cand_b = working_ids[:pos+1] + [mask_token_id] + working_ids[pos+1:]
|
| 371 |
+
score_b = _score_sequence(torch.tensor([cand_b], device=device))
|
| 372 |
+
if score_b > score_a:
|
| 373 |
+
working_ids = cand_b
|
| 374 |
+
offset += 1
|
| 375 |
+
|
| 376 |
+
# Step 6: Get final predictions using the same masked sequence as normalize_text
|
| 377 |
+
masked_input_ids = torch.tensor([working_ids], device=device)
|
| 378 |
+
with torch.no_grad():
|
| 379 |
+
final_outputs = self(input_ids=masked_input_ids, attention_mask=torch.ones_like(masked_input_ids))
|
| 380 |
+
logits_final = final_outputs.logits_norm if hasattr(final_outputs, 'logits_norm') else final_outputs.logits
|
| 381 |
+
pred_ids = torch.argmax(logits_final, dim=-1)[0].cpu().tolist()
|
| 382 |
+
|
| 383 |
+
# Step 7: Build results using the same logic as normalize_text
|
| 384 |
+
# We need to identify NSW tokens by comparing original vs predicted tokens
|
| 385 |
+
# This ensures we catch all tokens that were actually changed, not just those detected by NSW head
|
| 386 |
+
nsw_results = []
|
| 387 |
+
|
| 388 |
+
# Build final token ids by taking predictions at positions; keep originals at specials (same as normalize_text)
|
| 389 |
+
final_tokens = []
|
| 390 |
+
for idx, src_id in enumerate(working_ids):
|
| 391 |
+
tok = tokenizer.convert_ids_to_tokens([src_id])[0]
|
| 392 |
+
if tok in ['<s>', '</s>', '<pad>', '<unk>']:
|
| 393 |
+
final_tokens.append(src_id)
|
| 394 |
+
else:
|
| 395 |
+
final_tokens.append(pred_ids[idx] if idx < len(pred_ids) else src_id)
|
| 396 |
+
|
| 397 |
+
# Convert final tokens to normalized text (same as normalize_text)
|
| 398 |
+
def remove_special_tokens(token_list):
|
| 399 |
+
special_tokens = ['<s>', '</s>', '<pad>', '<unk>', '<mask>', '<space>']
|
| 400 |
+
return [token for token in token_list if token not in special_tokens]
|
| 401 |
+
|
| 402 |
+
def _safe_ids_to_text(token_ids):
|
| 403 |
+
if not token_ids:
|
| 404 |
+
return ""
|
| 405 |
+
try:
|
| 406 |
+
tokens = tokenizer.convert_ids_to_tokens(token_ids)
|
| 407 |
+
cleaned = remove_special_tokens(tokens)
|
| 408 |
+
if not cleaned:
|
| 409 |
+
return ""
|
| 410 |
+
return tokenizer.convert_tokens_to_string(cleaned)
|
| 411 |
+
except Exception:
|
| 412 |
+
return ""
|
| 413 |
+
|
| 414 |
+
# Build final normalized text
|
| 415 |
+
final_tokens_cleaned = [tid for tid in final_tokens if tid != -1]
|
| 416 |
+
normalized_text = _safe_ids_to_text(final_tokens_cleaned)
|
| 417 |
+
# Collapse repeated whitespace
|
| 418 |
+
if normalized_text:
|
| 419 |
+
normalized_text = ' '.join(normalized_text.split())
|
| 420 |
+
|
| 421 |
+
# Now compare original text tokens with normalized text tokens
|
| 422 |
+
original_tokens = tokenizer.tokenize(text)
|
| 423 |
+
normalized_tokens = tokenizer.tokenize(normalized_text)
|
| 424 |
+
|
| 425 |
+
# Use a smarter approach that can handle multi-token expansions
|
| 426 |
+
# Get the source and predicted tokens from the model
|
| 427 |
+
decoded_source = tokenizer.convert_ids_to_tokens(working_ids)
|
| 428 |
+
decoded_pred = tokenizer.convert_ids_to_tokens(final_tokens)
|
| 429 |
+
|
| 430 |
+
# Clean the tokens (remove special tokens and ▁ prefix)
|
| 431 |
+
def clean_token(token):
|
| 432 |
+
if token in ['<s>', '</s>', '<pad>', '<unk>', '<mask>']:
|
| 433 |
+
return None
|
| 434 |
+
return token.strip().lstrip('▁')
|
| 435 |
+
|
| 436 |
+
# Group consecutive predictions that form expansions
|
| 437 |
+
i = 0
|
| 438 |
+
while i < len(decoded_source):
|
| 439 |
+
src_token = decoded_source[i]
|
| 440 |
+
clean_src = clean_token(src_token)
|
| 441 |
+
|
| 442 |
+
if clean_src is None:
|
| 443 |
+
i += 1
|
| 444 |
+
continue
|
| 445 |
+
|
| 446 |
+
# Check if this token was changed
|
| 447 |
+
pred_token = decoded_pred[i]
|
| 448 |
+
clean_pred = clean_token(pred_token)
|
| 449 |
+
|
| 450 |
+
if clean_pred is None:
|
| 451 |
+
i += 1
|
| 452 |
+
continue
|
| 453 |
+
|
| 454 |
+
if clean_src != clean_pred:
|
| 455 |
+
# This is an NSW token - check if it's part of an expansion
|
| 456 |
+
expansion_tokens = [clean_pred]
|
| 457 |
+
j = i + 1
|
| 458 |
+
|
| 459 |
+
# Look for consecutive mask tokens that were filled
|
| 460 |
+
while j < len(decoded_source) and j < len(decoded_pred):
|
| 461 |
+
next_src = decoded_source[j]
|
| 462 |
+
next_pred = decoded_pred[j]
|
| 463 |
+
|
| 464 |
+
# If the source is a mask token, it was added for expansion
|
| 465 |
+
if next_src == '<mask>':
|
| 466 |
+
clean_next_pred = clean_token(next_pred)
|
| 467 |
+
if clean_next_pred is not None:
|
| 468 |
+
expansion_tokens.append(clean_next_pred)
|
| 469 |
+
j += 1
|
| 470 |
+
else:
|
| 471 |
+
# Check if the next source token was also changed
|
| 472 |
+
clean_next_src = clean_token(next_src)
|
| 473 |
+
clean_next_pred = clean_token(next_pred)
|
| 474 |
+
|
| 475 |
+
if clean_next_src is not None and clean_next_pred is not None and clean_next_src != clean_next_pred:
|
| 476 |
+
# This is also a changed token, might be part of expansion
|
| 477 |
+
# But we need to be careful not to group unrelated changes
|
| 478 |
+
# For now, let's be conservative and only group mask-based expansions
|
| 479 |
+
break
|
| 480 |
+
else:
|
| 481 |
+
break
|
| 482 |
+
|
| 483 |
+
# Create the expansion text
|
| 484 |
+
expansion_text = ' '.join(expansion_tokens)
|
| 485 |
+
|
| 486 |
+
# This is an NSW token
|
| 487 |
+
start_idx = text.find(clean_src)
|
| 488 |
+
end_idx = start_idx + len(clean_src) if start_idx != -1 else len(clean_src)
|
| 489 |
+
|
| 490 |
+
# Calculate confidence score
|
| 491 |
+
if hasattr(outputs, 'logits_nsw_detection') and outputs.logits_nsw_detection is not None:
|
| 492 |
+
# Find the corresponding position in the original token list
|
| 493 |
+
orig_pos = None
|
| 494 |
+
for k, tok in enumerate(tokens):
|
| 495 |
+
if tok.strip().lstrip('▁') == clean_src:
|
| 496 |
+
orig_pos = k
|
| 497 |
+
break
|
| 498 |
+
|
| 499 |
+
if orig_pos is not None and orig_pos < len(nsw_confidence):
|
| 500 |
+
if outputs.logits_nsw_detection.dim() == 3:
|
| 501 |
+
nsw_conf = nsw_confidence[orig_pos].item()
|
| 502 |
+
else:
|
| 503 |
+
nsw_conf = nsw_confidence[orig_pos].item()
|
| 504 |
+
else:
|
| 505 |
+
nsw_conf = 0.5 # Default if position not found
|
| 506 |
+
|
| 507 |
+
# Get normalization confidence
|
| 508 |
+
norm_logits = logits_final[0] # Use final masked logits
|
| 509 |
+
norm_confidence = torch.softmax(norm_logits, dim=-1)
|
| 510 |
+
norm_conf = norm_confidence[i][final_tokens[i]].item()
|
| 511 |
+
combined_confidence = (nsw_conf + norm_conf) / 2
|
| 512 |
+
else:
|
| 513 |
+
combined_confidence = 0.5 # Default confidence if no NSW detector
|
| 514 |
+
|
| 515 |
+
nsw_results.append({
|
| 516 |
+
'index': i,
|
| 517 |
+
'start_index': start_idx,
|
| 518 |
+
'end_index': end_idx,
|
| 519 |
+
'nsw': clean_src,
|
| 520 |
+
'prediction': expansion_text,
|
| 521 |
+
'confidence_score': round(combined_confidence, 4)
|
| 522 |
+
})
|
| 523 |
+
|
| 524 |
+
# Move to the next unprocessed token
|
| 525 |
+
i = j
|
| 526 |
+
else:
|
| 527 |
+
i += 1
|
| 528 |
+
|
| 529 |
+
return nsw_results
|
| 530 |
+
|
| 531 |
+
def _truncate_and_build_masks(self, input_tokens_tensor, output_tokens_tensor=None):
|
| 532 |
+
"""Apply the same truncation and masking logic as training."""
|
| 533 |
+
if hasattr(self, 'roberta'):
|
| 534 |
+
cfg_max = int(getattr(self.roberta.config, 'max_position_embeddings', input_tokens_tensor.size(1)))
|
| 535 |
+
tbl_max = int(getattr(self.roberta.embeddings.position_embeddings, 'num_embeddings', cfg_max))
|
| 536 |
+
max_pos = min(cfg_max, tbl_max)
|
| 537 |
+
eff_max = max(1, max_pos - 2)
|
| 538 |
+
if input_tokens_tensor.size(1) > eff_max:
|
| 539 |
+
input_tokens_tensor = input_tokens_tensor[:, :eff_max]
|
| 540 |
+
if output_tokens_tensor is not None and output_tokens_tensor.dim() == 2 and output_tokens_tensor.size(1) > eff_max:
|
| 541 |
+
output_tokens_tensor = output_tokens_tensor[:, :eff_max]
|
| 542 |
+
pad_id_model = getattr(self.roberta.config, 'pad_token_id', None)
|
| 543 |
+
if pad_id_model is None:
|
| 544 |
+
pad_id_model = getattr(self.roberta.embeddings.word_embeddings, 'padding_idx', None)
|
| 545 |
+
if pad_id_model is None:
|
| 546 |
+
pad_id_model = 1 # Default pad token ID
|
| 547 |
+
input_mask = (input_tokens_tensor != pad_id_model).long()
|
| 548 |
+
token_type_ids = torch.zeros_like(input_tokens_tensor)
|
| 549 |
+
return input_tokens_tensor, output_tokens_tensor, token_type_ids, input_mask
|
| 550 |
+
# bart branch
|
| 551 |
+
pad_id_model = 1
|
| 552 |
+
input_mask = torch.ones_like(input_tokens_tensor)
|
| 553 |
+
token_type_ids = None
|
| 554 |
+
return input_tokens_tensor, output_tokens_tensor, token_type_ids, input_mask
|
| 555 |
+
|
| 556 |
+
|
| 557 |
+
# Explicit public API: star-imports from this module expose only the model class.
__all__ = ["ViSoNormViSoBERTForMaskedLM"]
|