Commit db6aa40
Parent(s): d802797

feat: update model

Changed files:
- .gitattributes +1 -0
- README.md +129 -105
- UPLOAD_INSTRUCTIONS.md +195 -0
- classification_report.txt +11 -11
- config.json +3 -48
- confusion_matrix.png +2 -2
- model.safetensors +1 -1
- test_results.json +3 -22
- training_curves.png +2 -2
- training_scripts/run_training_auto.sh +115 -0
- training_scripts/run_training_manual.sh +124 -0
- training_scripts/train_nfqa_model.py +870 -0
.gitattributes
CHANGED

@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 *.png filter=lfs diff=lfs merge=lfs -text
+*.json filter=lfs diff=lfs merge=lfs -text
README.md
CHANGED

@@ -14,11 +14,23 @@ This model classifies questions across **49 languages** into **8 categories** of
 - **Categories**: 8 NFQA question types
 - **Parameters**: ~278M parameters
 - **Training Date**: January 2026
-- **License**:
 
 ### Developers
 
-Developed by
 
 ## Intended Use
 
@@ -38,52 +50,27 @@ Developed by [Your Name/Organization] for research in multilingual question unde
 
 ## Training Data
 
-### Dataset
-
-The model was trained on a carefully curated and balanced multilingual dataset:
-
-- **Total Examples**: 62,932 question-label pairs
-- **Source Data**:
-  - ~49,000 examples from the WebFAQ dataset (LLM-annotated with ensemble voting)
-  - ~14,000 examples generated and validated using LLMs to balance categories and languages
-- **Data Split**:
-  - Training: 44,051 examples (70%)
-  - Validation: 6,294 examples (10%)
-  - Test: 12,587 examples (20%)
-
-### Data Annotation & Balancing Process
-
-The dataset was created through a rigorous multi-step process combining LLM annotation and validation:
-
-**Phase 1: LLM Ensemble Annotation**
-- The original ~49,000 WebFAQ question-answer pairs were annotated using an ensemble of three language models:
-  - **LLaMA 3.1**
-  - **Gemma 2**
-  - **Qwen 2.5**
 
-**
-- Only high-quality annotations were retained using ensemble voting with a minimum confidence threshold of **0.6**:
-  - **Confidence 1.0**: All 3 models agree on the same label (unanimous)
-  - **Confidence 0.67**: At least 2 out of 3 models agree (majority vote)
-  - Annotations below 0.6 confidence were excluded to ensure label reliability
 
-**
-
-
 
-**
-
-
 
-**
-
-
 
-
-- Combined high-quality annotated and validated data achieving balanced representation:
-  - Each language: ~1,000 questions
-  - Each category per language: ~125 questions
-  - Diverse coverage across all 49 languages and 8 categories
 
 ### Languages Supported
 
@@ -113,7 +100,23 @@ Questions seeking factual, objective answers (who, what, when, where).
 - "When was the Eiffel Tower built?"
 - "Who invented the telephone?"
 
-### 3.
 How-to questions requiring step-by-step procedural answers.
 
 **Examples:**

@@ -121,7 +124,7 @@ How-to questions requiring step-by-step procedural answers.
 - "How to bake chocolate chip cookies?"
 - "How can I install Python on Windows?"
 
-###
 Why/how questions seeking explanations or reasoning.
 
 **Examples:**

@@ -129,22 +132,6 @@ Why/how questions seeking explanations or reasoning.
 - "How does photosynthesis work?"
 - "Why do birds migrate?"
 
-### 5. EVIDENCE-BASED (Label 4)
-Questions about definitions, features, or characteristics.
-
-**Examples:**
-- "What are the symptoms of flu?"
-- "What features does this phone have?"
-- "What is machine learning?"
-
-### 6. COMPARISON (Label 5)
-Questions comparing two or more options.
-
-**Examples:**
-- "iPhone vs Android: which is better?"
-- "What's the difference between RNA and DNA?"
-- "Compare electric and gas cars"
-
 ### 7. EXPERIENCE (Label 6)
 Questions seeking personal experiences, recommendations, or advice.

@@ -153,48 +140,54 @@ Questions seeking personal experiences, recommendations, or advice.
 - "Has anyone tried this restaurant?"
 - "Which hotel would you recommend?"
 
-### 8.
-
 
 **Examples:**
-- "
-- "
-- "
 
 ## Model Performance
 
-### Test Set Results (
 
-- **Overall Accuracy**: 88.
-- **Macro-Average F1**:
-- **Best Validation F1**:
 
 ### Per-Category Performance
 
 | Category | Precision | Recall | F1-Score | Support |
 |----------|-----------|--------|----------|---------|
-| NOT-A-QUESTION | 0.
-| FACTOID | 0.
-
-
-
-
-| EXPERIENCE | 0.82 | 0.76 | 0.79 |
-
 
 ### Key Observations
 
-- **Strongest Performance**:
-- **Good Performance**:
-- **Moderate Performance**:
-- The model generalizes well across all 49 languages
 
 ## Training Procedure
 
 ### Hardware
 
 - Training Device: CUDA-enabled GPU (NVIDIA)
-- Training Time:
 
 ### Hyperparameters

@@ -204,26 +197,47 @@ Hypothetical, opinion-based, or debatable questions.
     "max_length": 128,            # Maximum sequence length
     "batch_size": 16,             # Training batch size
     "learning_rate": 2e-5,        # AdamW learning rate
-    "num_epochs":
     "warmup_steps": 500,          # Linear warmup steps
     "weight_decay": 0.01,         # L2 regularization
     "optimizer": "AdamW",         # Optimizer
     "scheduler": "linear_warmup", # Learning rate scheduler
     "gradient_clipping": 1.0,     # Max gradient norm
-    "test_size": 0.2,             # 20% test split
-    "val_size": 0.1,              # 10% validation split
     "random_seed": 42             # Reproducibility
 }
 
 ### Training Process
 
-1. **Data Preparation**:
 2. **Preprocessing**: Tokenization using XLM-RoBERTa tokenizer (max length: 128 tokens)
 3. **Training Strategy**: Supervised fine-tuning with stratified train/val/test splits
 4. **Optimization**: AdamW optimizer with linear warmup and gradient clipping
-
-
 
 ## Usage

@@ -373,11 +387,11 @@ for r in results:
 
 ### Potential Biases
 
-- **Annotation Bias**: Labels are based on LLM ensemble predictions (
-- **Training Data Bias**: The model inherits biases from the
-- **Language Representation**:
-- **Category Distribution**:
-- **
 
 ### Recommendations for Use

@@ -400,7 +414,7 @@ If you use this model in your research, please cite:
 @misc{nfqa-multilingual-2026,
-    author = {
     title = {NFQA Multilingual Question Classifier},
     year = {2026},
     publisher = {HuggingFace},

@@ -409,12 +423,23 @@ If you use this model in your research, please cite:
 }
 
 ## Related Resources
 
-- **
-- **
-- **
-- **GitHub Repository**: [Link to your code repository]
 
 ## Model Card Contact

@@ -425,14 +450,13 @@ For questions, feedback, or issues:
 
 ## Acknowledgments
 
-- Training
-
-
-
-- Training infrastructure provided by University of Passau LLM inference server
 
 ---
 
 **Model Version**: 1.0
-**Last Updated**:
 **Status**: Production Ready
- **Categories**: 8 NFQA question types
- **Parameters**: ~278M parameters
- **Training Date**: January 2026
- **License**: apache-2.0

### Developers

Developed by Ali Salman for research in multilingual question understanding and classification.

### Architecture

The model is based on XLM-RoBERTa (Cross-lingual Language Model - Robustly Optimized BERT Approach), a transformer-based multilingual encoder:

- **Base Architecture**: 12-layer transformer encoder
- **Hidden Size**: 768
- **Attention Heads**: 12
- **Parameters**: ~278M
- **Vocabulary Size**: 250,000 tokens (SentencePiece)
- **Pre-training**: Trained on 2.5TB of CommonCrawl data in 100 languages
- **Fine-tuning**: Classification head with dropout (0.2) for 8-class NFQA classification

## Intended Use

## Training Data

### Dataset

The model was trained on the **[NFQA Multilingual Dataset](https://huggingface.co/datasets/AliSalman29/nfqa-multilingual-dataset)**, a large-scale multilingual dataset for non-factoid question classification.

**Dataset Composition**:
- **Training**: 33,602 examples (70%)
- **Validation**: 6,979 examples (15%)
- **Test**: 7,696 examples (15%)
- **Total**: 48,277 balanced examples

**Source Distribution**:
- 54% from the WebFAQ dataset (annotated with an LLM ensemble)
- 46% AI-generated to balance language-category combinations

**Key Features**:
- 392 unique (language, category) combinations
- Target of ~125 examples per combination
- Stratified sampling to ensure balanced representation
- Ensemble annotation using Llama 3.1, Gemma 2, and Qwen 2.5

For detailed information about dataset generation, annotation methodology, and data composition, please visit the [dataset page](https://huggingface.co/datasets/AliSalman29/nfqa-multilingual-dataset).
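The ensemble annotation step kept a label only when the three models agreed strongly enough: unanimous agreement scored 1.0, a 2-of-3 majority about 0.67, and anything below a 0.6 confidence threshold was discarded. A minimal sketch of that voting rule, assuming each model emits one category label per question (illustrative reconstruction, not the project's actual annotation code; the function name is made up):

```python
from collections import Counter

CONFIDENCE_THRESHOLD = 0.6  # minimum agreement retained in the dataset

def vote(labels):
    """Majority vote over per-model labels.

    Returns (label, confidence), where confidence is the fraction of
    models that agree. Returns (None, confidence) when agreement falls
    below the threshold, i.e. the example would be discarded.
    """
    label, count = Counter(labels).most_common(1)[0]
    confidence = count / len(labels)
    if confidence < CONFIDENCE_THRESHOLD:
        return None, confidence
    return label, confidence

# Unanimous agreement -> confidence 1.0, example kept
print(vote(["REASON", "REASON", "REASON"]))   # ('REASON', 1.0)
# 2-of-3 majority -> confidence ~0.67, example kept
print(vote(["REASON", "REASON", "FACTOID"]))
# Three-way disagreement -> below threshold, example dropped
print(vote(["REASON", "FACTOID", "DEBATE"]))
```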
### Languages Supported

- "When was the Eiffel Tower built?"
- "Who invented the telephone?"

### 3. DEBATE (Label 2)
Hypothetical, opinion-based, or debatable questions.

**Examples:**
- "Is artificial intelligence dangerous?"
- "Should we colonize Mars?"
- "Is remote work better than office work?"

### 4. EVIDENCE-BASED (Label 3)
Questions about definitions, features, or characteristics.

**Examples:**
- "What are the symptoms of flu?"
- "What features does this phone have?"
- "What is machine learning?"

### 5. INSTRUCTION (Label 4)
How-to questions requiring step-by-step procedural answers.

**Examples:**
- "How to bake chocolate chip cookies?"
- "How can I install Python on Windows?"

### 6. REASON (Label 5)
Why/how questions seeking explanations or reasoning.

**Examples:**
- "How does photosynthesis work?"
- "Why do birds migrate?"

### 7. EXPERIENCE (Label 6)
Questions seeking personal experiences, recommendations, or advice.

- "Has anyone tried this restaurant?"
- "Which hotel would you recommend?"

### 8. COMPARISON (Label 7)
Questions comparing two or more options.

**Examples:**
- "iPhone vs Android: which is better?"
- "What's the difference between RNA and DNA?"
- "Compare electric and gas cars"

## Model Performance

### Test Set Results (7,696 examples)

- **Overall Accuracy**: 88.1%
- **Macro-Average F1**: 88.1%
- **Best Validation F1**: 88.1% (achieved at epoch 6)

### Per-Category Performance

| Category | Precision | Recall | F1-Score | Support |
|----------|-----------|--------|----------|---------|
| NOT-A-QUESTION | 0.96 | 0.92 | 0.94 | 950 |
| FACTOID | 0.84 | 0.79 | 0.81 | 980 |
| DEBATE | 0.90 | 0.95 | 0.92 | 916 |
| EVIDENCE-BASED | 0.86 | 0.92 | 0.89 | 950 |
| INSTRUCTION | 0.85 | 0.92 | 0.88 | 980 |
| REASON | 0.88 | 0.86 | 0.87 | 960 |
| EXPERIENCE | 0.82 | 0.76 | 0.79 | 980 |
| COMPARISON | 0.93 | 0.93 | 0.93 | 980 |
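As a quick sanity check, the macro-average F1 can be recomputed from the rounded per-category scores in the table above (macro averaging weights every category equally, regardless of support):

```python
# Per-category F1 scores as reported in the table (rounded to 2 decimals).
f1_scores = {
    "NOT-A-QUESTION": 0.94, "FACTOID": 0.81, "DEBATE": 0.92,
    "EVIDENCE-BASED": 0.89, "INSTRUCTION": 0.88, "REASON": 0.87,
    "EXPERIENCE": 0.79, "COMPARISON": 0.93,
}

# Unweighted mean over the 8 categories.
macro_f1 = sum(f1_scores.values()) / len(f1_scores)
print(round(macro_f1, 3))  # ~0.879, consistent with the reported 88.1% up to rounding
```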
### Key Observations

- **Strongest Performance**: NOT-A-QUESTION, COMPARISON, and DEBATE categories (F1 ≥ 0.92)
- **Good Performance**: EVIDENCE-BASED, INSTRUCTION, and REASON categories (F1 ≥ 0.87)
- **Moderate Performance**: FACTOID and EXPERIENCE categories (F1 ~ 0.79-0.81)
- The model generalizes well across all 49 languages with a balanced test set distribution

### Confusion Matrix

![Confusion Matrix](confusion_matrix.png)

The confusion matrix shows the model's prediction patterns across all 8 categories. The diagonal elements represent correct classifications, while off-diagonal elements show misclassifications between categories.

## Training Procedure

### Hardware

- Training Device: CUDA-enabled GPU (NVIDIA)
- Training Time: 6 epochs to reach best performance

### Hyperparameters

    "max_length": 128,            # Maximum sequence length
    "batch_size": 16,             # Training batch size
    "learning_rate": 2e-5,        # AdamW learning rate
    "num_epochs": 6,              # Total epochs trained
    "warmup_steps": 500,          # Linear warmup steps
    "weight_decay": 0.01,         # L2 regularization
    "dropout": 0.2,               # Dropout probability
    "optimizer": "AdamW",         # Optimizer
    "scheduler": "linear_warmup", # Learning rate scheduler
    "gradient_clipping": 1.0,     # Max gradient norm
    "random_seed": 42             # Reproducibility
    }
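The `linear_warmup` scheduler named in the config ramps the learning rate from 0 up to the peak over the warmup steps and then decays it linearly back to 0 over the remaining steps. A minimal sketch of that schedule in plain Python, mirroring the common `get_linear_schedule_with_warmup` behaviour; the exact decay shape used by the training script is an assumption:

```python
PEAK_LR = 2e-5       # "learning_rate" from the config above
WARMUP_STEPS = 500   # "warmup_steps"
TOTAL_STEPS = 12606  # total optimizer steps over 6 epochs

def lr_at(step):
    """Learning rate after `step` optimizer updates (linear warmup, then linear decay)."""
    if step < WARMUP_STEPS:
        return PEAK_LR * step / WARMUP_STEPS
    remaining = max(0, TOTAL_STEPS - step)
    return PEAK_LR * remaining / (TOTAL_STEPS - WARMUP_STEPS)

print(lr_at(250))    # halfway through warmup: 1e-5
print(lr_at(500))    # peak: 2e-5
print(lr_at(12606))  # end of training: 0.0
```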
### Training Process

1. **Data Preparation**: Pre-split balanced dataset from the [NFQA Multilingual Dataset](https://huggingface.co/datasets/AliSalman29/nfqa-multilingual-dataset)
   - Training: 33,602 examples (70%)
   - Validation: 6,979 examples (15%)
   - Test: 7,696 examples (15%)

2. **Preprocessing**: Tokenization using the XLM-RoBERTa tokenizer (max length: 128 tokens)

3. **Training Strategy**: Supervised fine-tuning with stratified train/val/test splits
   - Stratified by (language, category) combinations to maintain balance

4. **Optimization**: AdamW optimizer with linear warmup and gradient clipping
   - Total training steps: 12,606 (33,602 examples × 6 epochs ÷ 16 batch size)
   - Warmup steps: 500

5. **Best Model Selection**: Model checkpoint with the highest validation F1 score (epoch 6)

6. **Evaluation**: Comprehensive testing on the held-out test set with per-category and per-language analysis
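The step count in step 4 follows directly from the split sizes and batch size; a quick check, assuming the last (smaller) batch of each epoch still counts as one optimizer step:

```python
import math

train, val, test = 33_602, 6_979, 7_696
assert train + val + test == 48_277  # matches the reported dataset total

steps_per_epoch = math.ceil(train / 16)  # 16 = batch size; partial final batch counts
total_steps = steps_per_epoch * 6        # 6 epochs
print(total_steps)  # 12606, as stated above
```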
### Training Curves

![Training Curves](training_curves.png)

The training curves show the model's learning progress across 6 epochs:
- **Left panel**: Training and validation loss over time
- **Middle panel**: Training and validation accuracy progression
- **Right panel**: Validation F1 score (macro average) with the best checkpoint marked

The model converged quickly, reaching optimal performance at epoch 6 with minimal overfitting.

## Usage

### Potential Biases

- **Annotation Bias**: Labels are based on LLM ensemble predictions (Llama 3.1, Gemma 2, Qwen 2.5) rather than human annotations, which may introduce systematic biases from these underlying models
- **Training Data Bias**: The model inherits biases from the WebFAQ dataset and the AI-generated examples
- **Language Representation**: While the dataset includes 49 languages, some language families may have different performance characteristics
- **Category Distribution**: The balanced dataset has similar representation across categories (~980 examples each in the test set), which may differ from real-world distributions
- **Domain Specificity**: Trained primarily on FAQ-style and general questions; performance may vary on domain-specific questions

### Recommendations for Use

```bibtex
@misc{nfqa-multilingual-2026,
    author = {Ali Salman},
    title = {NFQA Multilingual Question Classifier},
    year = {2026},
    publisher = {HuggingFace},
}
```

Please also cite the training dataset:

```bibtex
@dataset{nfqa_multilingual_dataset_2026,
    author = {Ali Salman},
    title = {NFQA Multilingual Dataset: A Large-Scale Dataset for Non-Factoid Question Classification},
    year = {2026},
    publisher = {Hugging Face},
    howpublished = {\url{https://huggingface.co/datasets/AliSalman29/nfqa-multilingual-dataset}}
}
```

## Related Resources

- **Training Dataset**: [NFQA Multilingual Dataset](https://huggingface.co/datasets/AliSalman29/nfqa-multilingual-dataset)
- **WebFAQ Dataset**: [PaDaS-Lab/webfaq](https://huggingface.co/datasets/PaDaS-Lab/webfaq)
- **XLM-RoBERTa**: [xlm-roberta-base](https://huggingface.co/xlm-roberta-base)

## Model Card Contact

## Acknowledgments

- Training dataset: [NFQA Multilingual Dataset](https://huggingface.co/datasets/AliSalman29/nfqa-multilingual-dataset)
- Source data: [WebFAQ Dataset](https://huggingface.co/datasets/PaDaS-Lab/webfaq)
- Built on the [XLM-RoBERTa](https://huggingface.co/xlm-roberta-base) foundation model by Meta AI
- Annotation and generation using Llama 3.1, Gemma 2, and Qwen 2.5

---

**Model Version**: 1.0
**Last Updated**: February 2026
**Status**: Production Ready
UPLOAD_INSTRUCTIONS.md
ADDED

@@ -0,0 +1,195 @@
# Instructions to Upload to Hugging Face

This repository is ready to be pushed to the Hugging Face Model Hub!

## Quick Setup (5 minutes)

### Step 1: Create Hugging Face Repository

1. Go to https://huggingface.co/new
2. Fill in:
   - **Model name**: `nfqa-multilingual-classifier`
   - **License**: Apache 2.0 (recommended) or your choice
   - **Visibility**: Public (or Private if you prefer)
3. Click **"Create model"**
4. **Important**: Copy your repository URL from the page

### Step 2: Get Your Access Token

1. Go to https://huggingface.co/settings/tokens
2. Click **"New token"**
3. Name: `model-upload`
4. Type: **Write** (important!)
5. Click **"Generate token"**
6. **Copy the token** (you won't see it again)

### Step 3: Connect This Repository

Replace `YOUR_USERNAME` with your actual Hugging Face username:

```bash
cd /Users/alisalman/thesis/nfqa-multilingual-classifier

# Add Hugging Face as remote
git remote add origin https://huggingface.co/YOUR_USERNAME/nfqa-multilingual-classifier

# Configure git to use your HF credentials
git config credential.helper store

# Push to Hugging Face (you'll be prompted for username and token)
git push -u origin master
```

When prompted:
- **Username**: Your Hugging Face username
- **Password**: Paste your access token (not your account password!)

### Step 4: Verify Upload

1. Go to `https://huggingface.co/YOUR_USERNAME/nfqa-multilingual-classifier`
2. You should see:
   - ✅ All model files (11 files)
   - ✅ README with full documentation
   - ✅ Training visualizations (confusion matrix, training curves)
   - ✅ Model card with usage examples
3. Test the **Inference API** widget with a question

---

## Alternative: Use Hugging Face CLI

If you prefer using the CLI:

```bash
# Install if not already installed
pip install --upgrade huggingface_hub

# Login
huggingface-cli login
# Paste your token when prompted

# Create repository
huggingface-cli repo create nfqa-multilingual-classifier --type model

# Upload
cd /Users/alisalman/thesis/nfqa-multilingual-classifier
huggingface-cli upload nfqa-multilingual-classifier . --repo-type model
```

---

## What's Included

This repository contains:

✅ **Model Files** (1.1 GB total):
- `model.safetensors` - Model weights
- `config.json` - Model configuration
- `tokenizer.json` - Tokenizer
- `tokenizer_config.json` - Tokenizer settings
- `sentencepiece.bpe.model` - Vocabulary
- `special_tokens_map.json` - Special tokens

✅ **Documentation**:
- `README.md` - Comprehensive model card
- `classification_report.txt` - Per-category performance
- `test_results.json` - Detailed evaluation metrics

✅ **Visualizations**:
- `confusion_matrix.png` - Test set confusion matrix
- `training_curves.png` - Training/validation curves

✅ **Git Configuration**:
- `.gitattributes` - LFS tracking for large files
- `.gitignore` - Ignore patterns

---

## Before You Push

### Update README Placeholders

Edit [README.md](README.md) and replace:
- `[Your Name/Organization]` → Your actual name
- `[Specify your license]` → Your license choice
- `your-username/nfqa-multilingual-classifier` → Your actual repo URL
- `[Your email]` → Your contact email
- `[Your repository]` → Your GitHub repo (if any)

You can edit directly on Hugging Face after uploading, or do it now:

```bash
nano README.md
# or use your preferred editor
```

---

## Troubleshooting

### Error: "Repository not found"
- Make sure you created the repository on huggingface.co first
- Check that the username in the URL matches your HF username

### Error: "Authentication failed"
- Make sure you're using your **token** as the password, not your account password
- Verify the token has **Write** permissions
- Try `git credential reject` to clear cached credentials

### Error: "Large file not properly tracked"
- LFS is already configured in this repo
- Just push normally; git-lfs will handle large files automatically

### Upload is very slow
- The model is ~1.1 GB, so this is normal
- It may take 5-15 minutes depending on your internet speed
- Git LFS uploads large files efficiently

---

## After Upload

1. **Test the model**:
   ```python
   from transformers import pipeline

   classifier = pipeline("text-classification",
                         model="YOUR_USERNAME/nfqa-multilingual-classifier")
   result = classifier("What is the capital of France?")
   print(result)
   ```

2. **Add widget examples** in the README YAML front matter (optional)

3. **Share your model** on social media, in papers, etc.

4. **Monitor usage** at `https://huggingface.co/YOUR_USERNAME/nfqa-multilingual-classifier/tree/main`

---

## Quick Reference

```bash
# View repository status
cd /Users/alisalman/thesis/nfqa-multilingual-classifier
git status

# View commit history
git log --oneline

# Check remote URL
git remote -v

# Push updates (after making changes)
git add .
git commit -m "Update model card"
git push
```

---

**Need help?**
- Hugging Face Docs: https://huggingface.co/docs/hub
- Git LFS Guide: https://git-lfs.github.com/

**Ready to push?** Follow Step 3 above!
classification_report.txt
CHANGED

@@ -1,14 +1,14 @@
                 precision    recall  f1-score   support
 
-NOT-A-QUESTION       0.
-       FACTOID       0.
-
-
-
-
-    EXPERIENCE       0.82      0.76      0.79
-
+NOT-A-QUESTION       0.96      0.92      0.94       950
+       FACTOID       0.84      0.79      0.81       980
+        DEBATE       0.90      0.95      0.92       916
+EVIDENCE-BASED       0.86      0.92      0.89       950
+   INSTRUCTION       0.85      0.92      0.88       980
+        REASON       0.88      0.86      0.87       960
+    EXPERIENCE       0.82      0.76      0.79       980
+    COMPARISON       0.93      0.93      0.93       980
 
-      accuracy       0.
-     macro avg       0.
-  weighted avg       0.
+      accuracy                           0.88      7696
+     macro avg       0.88      0.88      0.88      7696
+  weighted avg       0.88      0.88      0.88      7696
config.json
CHANGED

@@ -1,48 +1,3 @@
-
-
-
-  ],
-  "attention_probs_dropout_prob": 0.1,
-  "bos_token_id": 0,
-  "classifier_dropout": null,
-  "eos_token_id": 2,
-  "hidden_act": "gelu",
-  "hidden_dropout_prob": 0.1,
-  "hidden_size": 768,
-  "id2label": {
-    "0": "NOT-A-QUESTION",
-    "1": "FACTOID",
-    "2": "INSTRUCTION",
-    "3": "REASON",
-    "4": "EVIDENCE-BASED",
-    "5": "COMPARISON",
-    "6": "EXPERIENCE",
-    "7": "DEBATE"
-  },
-  "initializer_range": 0.02,
-  "intermediate_size": 3072,
-  "label2id": {
-    "COMPARISON": 5,
-    "DEBATE": 7,
-    "EVIDENCE-BASED": 4,
-    "EXPERIENCE": 6,
-    "FACTOID": 1,
-    "INSTRUCTION": 2,
-    "NOT-A-QUESTION": 0,
-    "REASON": 3
-  },
-  "layer_norm_eps": 1e-05,
-  "max_position_embeddings": 514,
-  "model_type": "xlm-roberta",
-  "num_attention_heads": 12,
-  "num_hidden_layers": 12,
-  "output_past": true,
-  "pad_token_id": 1,
-  "position_embedding_type": "absolute",
-  "problem_type": "single_label_classification",
-  "torch_dtype": "float32",
-  "transformers_version": "4.50.3",
-  "type_vocab_size": 1,
-  "use_cache": true,
-  "vocab_size": 250002
-}
+version https://git-lfs.github.com/spec/v1
+oid sha256:d64b32cf7198deee34a207a62d0681ea08b0b2ae51b5d011324791e5b24c6a9a
+size 1118
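The removed plain-text `config.json` carried the model's label maps; their mutual consistency can be checked with a short sketch. Note that the removed config's ordering (e.g. INSTRUCTION as id 2, DEBATE as id 7) differs from the updated model card's listing (DEBATE as Label 2), so the authoritative mapping is whatever the updated, LFS-tracked config contains:

```python
# id2label / label2id as they appeared in the removed config.json
id2label = {
    0: "NOT-A-QUESTION", 1: "FACTOID", 2: "INSTRUCTION", 3: "REASON",
    4: "EVIDENCE-BASED", 5: "COMPARISON", 6: "EXPERIENCE", 7: "DEBATE",
}
label2id = {
    "COMPARISON": 5, "DEBATE": 7, "EVIDENCE-BASED": 4, "EXPERIENCE": 6,
    "FACTOID": 1, "INSTRUCTION": 2, "NOT-A-QUESTION": 0, "REASON": 3,
}

# The two maps must be exact inverses of each other.
assert {v: k for k, v in id2label.items()} == label2id
print("label maps are consistent")
```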
confusion_matrix.png
CHANGED

Git LFS Details
model.safetensors
CHANGED

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:33a25f9cc0e6e82ac88d37fb2ec3bfb4e61c9751e5db98d13f04692a5ab2f734
 size 1112223464
test_results.json
CHANGED
@@ -1,22 +1,3 @@
-
-
-
-    "test_f1_macro": 0.8672321792444992,
-    "best_epoch": 27,
-    "best_val_f1": 0.8676620754981998,
-    "num_train_examples": 44051,
-    "num_val_examples": 6294,
-    "num_test_examples": 12587,
-    "config": {
-        "model_name": "xlm-roberta-base",
-        "max_length": 128,
-        "batch_size": 16,
-        "learning_rate": 2e-05,
-        "num_epochs": 30,
-        "warmup_steps": 500,
-        "weight_decay": 0.01,
-        "test_size": 0.2,
-        "val_size": 0.1
-    },
-    "timestamp": "2026-01-16T19:09:44.473503"
-}
+version https://git-lfs.github.com/spec/v1
+oid sha256:0b90b48f6007b8ab4fdc46e83e9dcf2561802f5946e48bd665fde14a9ba3fa7d
+size 778
training_curves.png
CHANGED
Git LFS Details (binary file updated; image diff not shown)
training_scripts/run_training_auto.sh
ADDED
@@ -0,0 +1,115 @@

#!/bin/bash
#
# Train NFQA Model with Automatic Data Splitting
#
# This script trains the NFQA classification model using a single combined
# dataset that is automatically split into train/val/test sets.
#
# Usage:
#   bash run_training_auto.sh
#
# Or with custom parameters:
#   bash run_training_auto.sh --epochs 15 --batch-size 32
#

set -e  # Exit on error

# Default paths
INPUT_FILE="../output/webfaq_nfqa_combined_highquality.jsonl"
OUTPUT_DIR="../output/training/nfqa_model_auto"

# Default training parameters
MODEL_NAME="xlm-roberta-base"
EPOCHS=6
BATCH_SIZE=16
LEARNING_RATE=2e-5
MAX_LENGTH=128
WARMUP_STEPS=500
WEIGHT_DECAY=0.1
DROPOUT=0.2
TEST_SIZE=0.2
VAL_SIZE=0.1

echo "================================================================================"
echo "NFQA Model Training - Automatic Split Mode"
echo "================================================================================"
echo ""
echo "Training Configuration:"
echo "  Input file:       $INPUT_FILE"
echo "  Output directory: $OUTPUT_DIR"
echo "  Model:            $MODEL_NAME"
echo "  Epochs:           $EPOCHS"
echo "  Batch size:      $BATCH_SIZE"
echo "  Learning rate:    $LEARNING_RATE"
echo "  Max length:       $MAX_LENGTH"
echo "  Weight decay:     $WEIGHT_DECAY"
echo "  Dropout:          $DROPOUT"
echo "  Test split:       $TEST_SIZE (20%)"
echo "  Val split:        $VAL_SIZE (10%)"
echo ""
echo "================================================================================"
echo ""

# Check if input file exists
if [ ! -f "$INPUT_FILE" ]; then
    echo "❌ Error: Input file not found: $INPUT_FILE"
    echo ""
    echo "Please ensure the combined dataset exists."
    echo "You can create it by running:"
    echo "  cd ../annotator"
    echo "  python combine_datasets.py"
    exit 1
fi

# Create output directory
mkdir -p "$OUTPUT_DIR"

# Run training. Note: with `set -e` in effect, a plain `$?` check after the
# command would never observe a failure, so the invocation is wrapped in `if`.
if python train_nfqa_model.py \
    --input "$INPUT_FILE" \
    --output-dir "$OUTPUT_DIR" \
    --model-name "$MODEL_NAME" \
    --epochs "$EPOCHS" \
    --batch-size "$BATCH_SIZE" \
    --learning-rate "$LEARNING_RATE" \
    --max-length "$MAX_LENGTH" \
    --warmup-steps "$WARMUP_STEPS" \
    --weight-decay "$WEIGHT_DECAY" \
    --dropout "$DROPOUT" \
    --test-size "$TEST_SIZE" \
    --val-size "$VAL_SIZE" \
    "$@"  # Pass any additional arguments from the command line
then
    echo ""
    echo "================================================================================"
    echo "✅ Training completed successfully!"
    echo "================================================================================"
    echo ""
    echo "Model saved to: $OUTPUT_DIR"
    echo ""
    echo "Generated files:"
    echo "  - best_model/ (best checkpoint based on validation F1)"
    echo "  - final_model/ (final epoch checkpoint)"
    echo "  - training_history.json (training metrics)"
    echo "  - training_curves.png (loss/accuracy/F1 plots)"
    echo "  - test_results.json (final test metrics)"
    echo "  - classification_report.txt (per-category performance)"
    echo "  - confusion_matrix.png (confusion matrix visualization)"
    echo ""
    echo "Next steps:"
    echo "  1. Review training curves: $OUTPUT_DIR/training_curves.png"
    echo "  2. Check test results: $OUTPUT_DIR/test_results.json"
    echo "  3. Analyze confusion matrix: $OUTPUT_DIR/confusion_matrix.png"
    echo "  4. Deploy model from: $OUTPUT_DIR/best_model/"
    echo ""
else
    echo ""
    echo "================================================================================"
    echo "❌ Training failed!"
    echo "================================================================================"
    echo ""
    echo "Please check the error messages above and try again."
    exit 1
fi
training_scripts/run_training_manual.sh
ADDED
@@ -0,0 +1,124 @@

#!/bin/bash
#
# Train NFQA Model with Pre-Split Datasets
#
# This script trains the NFQA classification model using manually split
# train/validation/test datasets for balanced training.
#
# Usage:
#   bash run_training_manual.sh
#
# Or with custom parameters:
#   bash run_training_manual.sh --epochs 15 --batch-size 32
#

set -e  # Exit on error

# Default paths
TRAIN_FILE="../output/train_balanced.jsonl"
VAL_FILE="../output/val_balanced.jsonl"
TEST_FILE="../output/test_balanced.jsonl"
OUTPUT_DIR="../output/training/nfqa_model_balanced"

# Default training parameters
MODEL_NAME="xlm-roberta-base"
EPOCHS=6
BATCH_SIZE=16
LEARNING_RATE=2e-5
MAX_LENGTH=128
WARMUP_STEPS=500
WEIGHT_DECAY=0.1
DROPOUT=0.2

echo "================================================================================"
echo "NFQA Model Training - Manual Split Mode"
echo "================================================================================"
echo ""
echo "Training Configuration:"
echo "  Train file:       $TRAIN_FILE"
echo "  Validation file:  $VAL_FILE"
echo "  Test file:        $TEST_FILE"
echo "  Output directory: $OUTPUT_DIR"
echo "  Model:            $MODEL_NAME"
echo "  Epochs:           $EPOCHS"
echo "  Batch size:       $BATCH_SIZE"
echo "  Learning rate:    $LEARNING_RATE"
echo "  Max length:       $MAX_LENGTH"
echo "  Weight decay:     $WEIGHT_DECAY"
echo "  Dropout:          $DROPOUT"
echo ""
echo "================================================================================"
echo ""

# Check if required files exist
if [ ! -f "$TRAIN_FILE" ]; then
    echo "❌ Error: Training file not found: $TRAIN_FILE"
    echo ""
    echo "Please run the data splitting script first:"
    echo "  cd ../cleaning"
    echo "  python split_train_test_val.py --input ../output/webfaq_nfqa_combined_highquality.jsonl"
    exit 1
fi

if [ ! -f "$VAL_FILE" ]; then
    echo "❌ Error: Validation file not found: $VAL_FILE"
    exit 1
fi

if [ ! -f "$TEST_FILE" ]; then
    echo "❌ Error: Test file not found: $TEST_FILE"
    exit 1
fi

# Create output directory
mkdir -p "$OUTPUT_DIR"

# Run training. Note: with `set -e` in effect, checking `$?` afterwards would
# never observe a failure, so the invocation is wrapped in `if` instead.
if python train_nfqa_model.py \
    --train "$TRAIN_FILE" \
    --val "$VAL_FILE" \
    --test "$TEST_FILE" \
    --output-dir "$OUTPUT_DIR" \
    --model-name "$MODEL_NAME" \
    --epochs "$EPOCHS" \
    --batch-size "$BATCH_SIZE" \
    --learning-rate "$LEARNING_RATE" \
    --max-length "$MAX_LENGTH" \
    --warmup-steps "$WARMUP_STEPS" \
    --weight-decay "$WEIGHT_DECAY" \
    --dropout "$DROPOUT" \
    "$@"  # Pass any additional arguments from the command line
then
    echo ""
    echo "================================================================================"
    echo "✅ Training completed successfully!"
    echo "================================================================================"
    echo ""
    echo "Model saved to: $OUTPUT_DIR"
    echo ""
    echo "Generated files:"
    echo "  - best_model/ (best checkpoint based on validation F1)"
    echo "  - final_model/ (final epoch checkpoint)"
    echo "  - training_history.json (training metrics)"
    echo "  - training_curves.png (loss/accuracy/F1 plots)"
    echo "  - test_results.json (final test metrics)"
    echo "  - classification_report.txt (per-category performance)"
    echo "  - confusion_matrix.png (confusion matrix visualization)"
    echo ""
    echo "Next steps:"
    echo "  1. Review training curves: $OUTPUT_DIR/training_curves.png"
    echo "  2. Check test results: $OUTPUT_DIR/test_results.json"
    echo "  3. Analyze confusion matrix: $OUTPUT_DIR/confusion_matrix.png"
    echo "  4. Deploy model from: $OUTPUT_DIR/best_model/"
    echo ""
else
    echo ""
    echo "================================================================================"
    echo "❌ Training failed!"
    echo "================================================================================"
    echo ""
    echo "Please check the error messages above and try again."
    exit 1
fi
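Before launching a manual-split run, it can be worth sanity-checking the pre-split JSONL files. A minimal sketch (hypothetical `validate_jsonl_lines` helper, stdlib only) that mirrors the field names the training script accepts, namely `question` plus one of `label`, `label_id`, or `ensemble_prediction`:

```python
import json
from typing import Iterable

# Label columns accepted by train_nfqa_model.py's load_data()
LABEL_COLUMNS = ("label", "label_id", "ensemble_prediction")

def validate_jsonl_lines(lines: Iterable[str]) -> int:
    """Parse JSONL records and confirm each has a 'question' field plus one
    recognized label field; returns the number of valid records."""
    n = 0
    for line in lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        record = json.loads(line)
        if "question" not in record:
            raise ValueError("missing 'question' field")
        if not any(col in record for col in LABEL_COLUMNS):
            raise ValueError("no recognized label field")
        n += 1
    return n
```

Running this over each of the three balanced files before training catches malformed lines early, instead of partway through pandas' `read_json` in the training script.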
training_scripts/train_nfqa_model.py
ADDED
@@ -0,0 +1,870 @@
#!/usr/bin/env python3
"""
Train NFQA Classification Model from Scratch

Trains a multilingual NFQA classifier using XLM-RoBERTa on LLM-annotated WebFAQ data.

Usage (single file with automatic splitting):
    python train_nfqa_model.py --input data.jsonl --output-dir ./model --epochs 10

Usage (pre-split files):
    python train_nfqa_model.py --train train.jsonl --val val.jsonl --test test.jsonl --output-dir ./model --epochs 10

Author: Ali
Date: December 2024
"""

import pandas as pd
import numpy as np
import torch
import json
import argparse
import os
from collections import Counter
from datetime import datetime
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    get_linear_schedule_with_warmup
)
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    accuracy_score,
    f1_score
)
import matplotlib
matplotlib.use('Agg')  # Non-interactive backend for server
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

# Set random seed
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

NFQA_CATEGORIES = [
    'NOT-A-QUESTION',
    'FACTOID',
    'DEBATE',
    'EVIDENCE-BASED',
    'INSTRUCTION',
    'REASON',
    'EXPERIENCE',
    'COMPARISON'
]

# Label mappings
LABEL2ID = {label: idx for idx, label in enumerate(NFQA_CATEGORIES)}
ID2LABEL = {idx: label for label, idx in LABEL2ID.items()}


class NFQADataset(Dataset):
    """Custom dataset for NFQA classification"""

    def __init__(self, questions, labels, tokenizer, max_length=128):
        self.questions = questions
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.questions)

    def __getitem__(self, idx):
        question = str(self.questions[idx])
        label = int(self.labels[idx])

        # Tokenize
        encoding = self.tokenizer(
            question,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }


def train_epoch(model, train_loader, optimizer, scheduler, device):
    """Train for one epoch"""
    model.train()
    total_loss = 0
    predictions = []
    true_labels = []

    progress_bar = tqdm(train_loader, desc="Training")

    for batch in progress_bar:
        # Move batch to device
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # Forward pass
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )

        loss = outputs.loss
        total_loss += loss.item()

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        # Track predictions
        preds = torch.argmax(outputs.logits, dim=1)
        predictions.extend(preds.cpu().numpy())
        true_labels.extend(labels.cpu().numpy())

        # Update progress bar
        progress_bar.set_postfix({'loss': f'{loss.item():.4f}'})

    avg_loss = total_loss / len(train_loader)
    accuracy = accuracy_score(true_labels, predictions)

    return avg_loss, accuracy


def evaluate(model, data_loader, device, languages=None, desc="Evaluating", show_analysis=False):
    """Evaluate model on validation/test set with optional detailed analysis"""
    model.eval()
    total_loss = 0
    predictions = []
    true_labels = []

    with torch.no_grad():
        for batch in tqdm(data_loader, desc=desc):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )

            total_loss += outputs.loss.item()

            preds = torch.argmax(outputs.logits, dim=1)
            predictions.extend(preds.cpu().numpy())
            true_labels.extend(labels.cpu().numpy())

    avg_loss = total_loss / len(data_loader)
    accuracy = accuracy_score(true_labels, predictions)
    f1 = f1_score(true_labels, predictions, average='macro')

    # Run detailed analysis if requested
    if show_analysis and languages is not None:
        print("\n" + "-"*70)
        print("VALIDATION ANALYSIS")
        print("-"*70)

        # Analyze by category
        analyze_performance_by_category(predictions, true_labels)

        # Analyze by language (top 5)
        analyze_performance_by_language(predictions, true_labels, languages, top_n=5)

        # Analyze combinations (top 10)
        analyze_language_category_combinations(predictions, true_labels, languages, top_n=10)

        print("-"*70)

    return avg_loss, accuracy, f1, predictions, true_labels


def load_data(file_path):
    """Load annotated data from JSONL file"""
    print(f"Loading data from: {file_path}\n")

    try:
        df = pd.read_json(file_path, lines=True)
        print(f"✓ Loaded {len(df)} annotated examples")

        # Check required columns
        if 'question' not in df.columns:
            raise ValueError("Missing 'question' column")

        # Determine label column
        if 'label_id' in df.columns:
            label_col = 'label_id'
        elif 'ensemble_prediction' in df.columns:
            # Convert category names to IDs
            df['label_id'] = df['ensemble_prediction'].map(LABEL2ID)
            label_col = 'label_id'
        elif 'label' in df.columns:
            label_col = 'label'
        else:
            raise ValueError("No label column found (expected: 'label', 'label_id', or 'ensemble_prediction')")

        # Remove any rows with missing labels
        df = df.dropna(subset=['question', label_col])

        print(f"✓ Data cleaned: {len(df)} examples with valid labels")

        # Show statistics
        print("\nLabel distribution:")
        label_counts = df[label_col].value_counts().sort_index()
        for label_id, count in label_counts.items():
            cat_name = ID2LABEL.get(int(label_id), f"UNKNOWN_{label_id}")
            print(f"  {cat_name:20s}: {count:4d} ({count/len(df)*100:5.1f}%)")

        # Prepare final dataset with language info
        questions = df['question'].tolist()
        labels = df[label_col].astype(int).tolist()
        languages = df['language'].tolist() if 'language' in df.columns else ['unknown'] * len(df)

        print(f"\n✓ Prepared {len(questions)} question-label pairs")

        return questions, labels, languages

    except FileNotFoundError:
        print(f"❌ Error: File not found: {file_path}")
        raise
    except Exception as e:
        print(f"❌ Error loading data: {e}")
        raise


def create_data_splits(questions, labels, test_size=0.2, val_size=0.1):
    """Create train/val/test splits"""
    print("\nCreating data splits...")

    # First split: separate test set
    train_val_questions, test_questions, train_val_labels, test_labels = train_test_split(
        questions,
        labels,
        test_size=test_size,
        random_state=RANDOM_SEED,
        stratify=labels
    )

    # Second split: separate validation from training
    train_questions, val_questions, train_labels, val_labels = train_test_split(
        train_val_questions,
        train_val_labels,
        test_size=val_size / (1 - test_size),
        random_state=RANDOM_SEED,
        stratify=train_val_labels
    )

    print(f"\nData splits:")
    print(f"  Training:   {len(train_questions):4d} examples ({len(train_questions)/len(questions)*100:5.1f}%)")
    print(f"  Validation: {len(val_questions):4d} examples ({len(val_questions)/len(questions)*100:5.1f}%)")
    print(f"  Test:       {len(test_questions):4d} examples ({len(test_questions)/len(questions)*100:5.1f}%)")
    print(f"  Total:      {len(questions):4d} examples")

    # Verify class distribution
    print("\nClass distribution per split:")
    for split_name, split_labels in [('Train', train_labels), ('Val', val_labels), ('Test', test_labels)]:
        counts = Counter(split_labels)
        print(f"\n{split_name}:")
        for label_id in sorted(counts.keys()):
            cat_name = ID2LABEL[label_id]
            print(f"  {cat_name:20s}: {counts[label_id]:3d}")

    return train_questions, val_questions, test_questions, train_labels, val_labels, test_labels


def plot_training_curves(history, best_val_f1, output_dir):
    """Plot and save training curves"""
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    epochs = range(1, len(history['train_loss']) + 1)

    # Plot 1: Loss
    axes[0].plot(epochs, history['train_loss'], 'b-', label='Train Loss', linewidth=2)
    axes[0].plot(epochs, history['val_loss'], 'r-', label='Val Loss', linewidth=2)
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss')
    axes[0].set_title('Training and Validation Loss')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)

    # Plot 2: Accuracy
    axes[1].plot(epochs, history['train_accuracy'], 'b-', label='Train Accuracy', linewidth=2)
    axes[1].plot(epochs, history['val_accuracy'], 'r-', label='Val Accuracy', linewidth=2)
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Accuracy')
    axes[1].set_title('Training and Validation Accuracy')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)

    # Plot 3: F1 Score
    axes[2].plot(epochs, history['val_f1'], 'g-', label='Val F1 (Macro)', linewidth=2)
    axes[2].axhline(y=best_val_f1, color='r', linestyle='--', label=f'Best F1: {best_val_f1:.4f}')
    axes[2].set_xlabel('Epoch')
    axes[2].set_ylabel('F1 Score')
    axes[2].set_title('Validation F1 Score')
    axes[2].legend()
    axes[2].grid(True, alpha=0.3)

    plt.tight_layout()
    plot_file = os.path.join(output_dir, 'training_curves.png')
    plt.savefig(plot_file, dpi=300, bbox_inches='tight')
    plt.close()

    print(f"✓ Training curves saved to: {plot_file}")


def analyze_performance_by_language(predictions, true_labels, languages, top_n=10):
    """Analyze and print performance by language"""
    from collections import defaultdict

    lang_stats = defaultdict(lambda: {'correct': 0, 'total': 0})

    for pred, true, lang in zip(predictions, true_labels, languages):
        lang_stats[lang]['total'] += 1
        if pred == true:
            lang_stats[lang]['correct'] += 1

    # Calculate accuracy per language
    lang_accuracies = []
    for lang, stats in lang_stats.items():
        if stats['total'] >= 5:  # Only show languages with at least 5 examples
            acc = stats['correct'] / stats['total']
            lang_accuracies.append({
                'language': lang,
                'accuracy': acc,
                'correct': stats['correct'],
                'total': stats['total'],
                'errors': stats['total'] - stats['correct']
            })

    lang_accuracies.sort(key=lambda x: x['accuracy'])

    print(f"\n{'='*70}")
    print(f"WORST {top_n} LANGUAGES (with >= 5 examples)")
    print(f"{'='*70}")
    print(f"{'Language':<12} {'Accuracy':<12} {'Errors':<10} {'Total':<10}")
    print(f"{'-'*70}")

    for item in lang_accuracies[:top_n]:
        print(f"{item['language']:<12} {item['accuracy']:>10.2%} {item['errors']:>8} {item['total']:>8}")

    return lang_stats, lang_accuracies


def analyze_performance_by_category(predictions, true_labels):
    """Analyze and print performance by category"""
    from collections import defaultdict

    cat_stats = defaultdict(lambda: {'correct': 0, 'total': 0})

    for pred, true in zip(predictions, true_labels):
        cat_stats[true]['total'] += 1
        if pred == true:
            cat_stats[true]['correct'] += 1

    cat_accuracies = []
    for cat_id, stats in cat_stats.items():
        acc = stats['correct'] / stats['total']
        cat_accuracies.append({
            'category': ID2LABEL[cat_id],
            'accuracy': acc,
            'correct': stats['correct'],
            'total': stats['total'],
            'errors': stats['total'] - stats['correct']
        })

    cat_accuracies.sort(key=lambda x: x['accuracy'])

    print(f"\n{'='*70}")
    print(f"PERFORMANCE BY CATEGORY")
    print(f"{'='*70}")
    print(f"{'Category':<20} {'Accuracy':<12} {'Errors':<10} {'Total':<10}")
    print(f"{'-'*70}")

    for item in cat_accuracies:
        print(f"{item['category']:<20} {item['accuracy']:>10.2%} {item['errors']:>8} {item['total']:>8}")

    return cat_stats, cat_accuracies


def analyze_language_category_combinations(predictions, true_labels, languages, top_n=15):
    """Analyze performance by (language, category) combinations"""
    from collections import defaultdict

    combo_stats = defaultdict(lambda: {'correct': 0, 'total': 0})

    for pred, true, lang in zip(predictions, true_labels, languages):
|
| 410 |
+
key = (lang, ID2LABEL[true])
|
| 411 |
+
combo_stats[key]['total'] += 1
|
| 412 |
+
if pred == true:
|
| 413 |
+
combo_stats[key]['correct'] += 1
|
| 414 |
+
|
| 415 |
+
combo_accuracies = []
|
| 416 |
+
for (lang, cat), stats in combo_stats.items():
|
| 417 |
+
if stats['total'] >= 3: # Only show combinations with at least 3 examples
|
| 418 |
+
acc = stats['correct'] / stats['total']
|
| 419 |
+
combo_accuracies.append({
|
| 420 |
+
'language': lang,
|
| 421 |
+
'category': cat,
|
| 422 |
+
'accuracy': acc,
|
| 423 |
+
'correct': stats['correct'],
|
| 424 |
+
'total': stats['total'],
|
| 425 |
+
'errors': stats['total'] - stats['correct']
|
| 426 |
+
})
|
| 427 |
+
|
| 428 |
+
combo_accuracies.sort(key=lambda x: x['accuracy'])
|
| 429 |
+
|
| 430 |
+
print(f"\n{'='*80}")
|
| 431 |
+
print(f"WORST {top_n} LANGUAGE-CATEGORY COMBINATIONS (with >= 3 examples)")
|
| 432 |
+
print(f"{'='*80}")
|
| 433 |
+
print(f"{'Language':<12} {'Category':<20} {'Accuracy':<12} {'Errors':<8} {'Total':<8}")
|
| 434 |
+
print(f"{'-'*80}")
|
| 435 |
+
|
| 436 |
+
for item in combo_accuracies[:top_n]:
|
| 437 |
+
print(f"{item['language']:<12} {item['category']:<20} {item['accuracy']:>10.2%} {item['errors']:>6} {item['total']:>6}")
|
| 438 |
+
|
| 439 |
+
return combo_stats, combo_accuracies
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
def plot_confusion_matrix(test_true, test_preds, output_dir):
|
| 443 |
+
"""Plot and save confusion matrix"""
|
| 444 |
+
cm = confusion_matrix(test_true, test_preds, labels=list(range(len(NFQA_CATEGORIES))))
|
| 445 |
+
|
| 446 |
+
plt.figure(figsize=(12, 10))
|
| 447 |
+
sns.heatmap(
|
| 448 |
+
cm,
|
| 449 |
+
annot=True,
|
| 450 |
+
fmt='d',
|
| 451 |
+
cmap='Blues',
|
| 452 |
+
xticklabels=NFQA_CATEGORIES,
|
| 453 |
+
yticklabels=NFQA_CATEGORIES,
|
| 454 |
+
cbar_kws={'label': 'Count'}
|
| 455 |
+
)
|
| 456 |
+
plt.xlabel('Predicted Category')
|
| 457 |
+
plt.ylabel('True Category')
|
| 458 |
+
plt.title('Confusion Matrix - Test Set')
|
| 459 |
+
plt.xticks(rotation=45, ha='right')
|
| 460 |
+
plt.yticks(rotation=0)
|
| 461 |
+
plt.tight_layout()
|
| 462 |
+
|
| 463 |
+
cm_file = os.path.join(output_dir, 'confusion_matrix.png')
|
| 464 |
+
plt.savefig(cm_file, dpi=300, bbox_inches='tight')
|
| 465 |
+
plt.close()
|
| 466 |
+
|
| 467 |
+
print(f"✓ Confusion matrix saved to: {cm_file}")
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
def main():
|
| 471 |
+
parser = argparse.ArgumentParser(description='Train NFQA Classification Model')
|
| 472 |
+
|
| 473 |
+
# Data arguments - either single input file OR separate train/val/test files
|
| 474 |
+
parser.add_argument('--input', type=str,
|
| 475 |
+
help='Input JSONL file with annotated data (will be split automatically)')
|
| 476 |
+
parser.add_argument('--train', type=str,
|
| 477 |
+
help='Training set JSONL file (use with --val and --test)')
|
| 478 |
+
parser.add_argument('--val', type=str,
|
| 479 |
+
help='Validation set JSONL file (use with --train and --test)')
|
| 480 |
+
parser.add_argument('--test', type=str,
|
| 481 |
+
help='Test set JSONL file (use with --train and --val)')
|
| 482 |
+
parser.add_argument('--output-dir', type=str, default='./nfqa_model_trained',
|
| 483 |
+
help='Output directory for model and results')
|
| 484 |
+
|
| 485 |
+
# Model arguments
|
| 486 |
+
parser.add_argument('--model-name', type=str, default='xlm-roberta-base',
|
| 487 |
+
help='Pretrained model name (default: xlm-roberta-base)')
|
| 488 |
+
parser.add_argument('--max-length', type=int, default=128,
|
| 489 |
+
help='Maximum sequence length (default: 128)')
|
| 490 |
+
|
| 491 |
+
# Training arguments
|
| 492 |
+
parser.add_argument('--batch-size', type=int, default=16,
|
| 493 |
+
help='Batch size (default: 16)')
|
| 494 |
+
parser.add_argument('--epochs', type=int, default=10,
|
| 495 |
+
help='Number of epochs (default: 10)')
|
| 496 |
+
parser.add_argument('--learning-rate', type=float, default=2e-5,
|
| 497 |
+
help='Learning rate (default: 2e-5)')
|
| 498 |
+
parser.add_argument('--warmup-steps', type=int, default=500,
|
| 499 |
+
help='Warmup steps (default: 500)')
|
| 500 |
+
parser.add_argument('--weight-decay', type=float, default=0.01,
|
| 501 |
+
help='Weight decay (default: 0.01)')
|
| 502 |
+
parser.add_argument('--dropout', type=float, default=0.1,
|
| 503 |
+
help='Dropout probability (default: 0.1)')
|
| 504 |
+
|
| 505 |
+
# Split arguments
|
| 506 |
+
parser.add_argument('--test-size', type=float, default=0.2,
|
| 507 |
+
help='Test set size (default: 0.2)')
|
| 508 |
+
parser.add_argument('--val-size', type=float, default=0.1,
|
| 509 |
+
help='Validation set size (default: 0.1)')
|
| 510 |
+
|
| 511 |
+
# Device argument
|
| 512 |
+
parser.add_argument('--device', type=str, default='auto',
|
| 513 |
+
help='Device to use: cuda, cpu, or auto (default: auto)')
|
| 514 |
+
|
| 515 |
+
args = parser.parse_args()
|
| 516 |
+
|
| 517 |
+
# Validate arguments
|
| 518 |
+
has_single_input = args.input is not None
|
| 519 |
+
has_split_inputs = all([args.train, args.val, args.test])
|
| 520 |
+
|
| 521 |
+
if not has_single_input and not has_split_inputs:
|
| 522 |
+
parser.error("Either --input OR (--train, --val, --test) must be provided")
|
| 523 |
+
|
| 524 |
+
if has_single_input and has_split_inputs:
|
| 525 |
+
parser.error("Cannot use --input together with --train/--val/--test. Choose one approach.")
|
| 526 |
+
|
| 527 |
+
# Print configuration
|
| 528 |
+
print("="*80)
|
| 529 |
+
print("NFQA MODEL TRAINING")
|
| 530 |
+
print("="*80)
|
| 531 |
+
if has_single_input:
|
| 532 |
+
print(f"Input file: {args.input}")
|
| 533 |
+
print(f"Data splitting: automatic (test={args.test_size}, val={args.val_size})")
|
| 534 |
+
else:
|
| 535 |
+
print(f"Train file: {args.train}")
|
| 536 |
+
print(f"Val file: {args.val}")
|
| 537 |
+
print(f"Test file: {args.test}")
|
| 538 |
+
print(f"Data splitting: manual (pre-split)")
|
| 539 |
+
print(f"Output directory: {args.output_dir}")
|
| 540 |
+
print(f"Model: {args.model_name}")
|
| 541 |
+
print(f"Epochs: {args.epochs}")
|
| 542 |
+
print(f"Batch size: {args.batch_size}")
|
| 543 |
+
print(f"Learning rate: {args.learning_rate}")
|
| 544 |
+
print(f"Max length: {args.max_length}")
|
| 545 |
+
print(f"Weight decay: {args.weight_decay}")
|
| 546 |
+
print(f"Dropout: {args.dropout}")
|
| 547 |
+
print("="*80 + "\n")
|
| 548 |
+
|
| 549 |
+
# Set device
|
| 550 |
+
if args.device == 'auto':
|
| 551 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 552 |
+
else:
|
| 553 |
+
device = torch.device(args.device)
|
| 554 |
+
|
| 555 |
+
if torch.cuda.is_available():
|
| 556 |
+
torch.cuda.manual_seed_all(RANDOM_SEED)
|
| 557 |
+
|
| 558 |
+
print(f"Device: {device}")
|
| 559 |
+
print(f"PyTorch version: {torch.__version__}")
|
| 560 |
+
if torch.cuda.is_available():
|
| 561 |
+
print(f"CUDA device: {torch.cuda.get_device_name(0)}\n")
|
| 562 |
+
|
| 563 |
+
# Create output directory
|
| 564 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 565 |
+
|
| 566 |
+
# Load data - either from single file or pre-split files
|
| 567 |
+
if has_single_input:
|
| 568 |
+
# Load single file and create splits
|
| 569 |
+
questions, labels, languages = load_data(args.input)
|
| 570 |
+
|
| 571 |
+
# Create splits (stratify by labels, keep languages aligned)
|
| 572 |
+
from sklearn.model_selection import train_test_split
|
| 573 |
+
# First split: separate test set
|
| 574 |
+
train_val_questions, test_questions, train_val_labels, test_labels, train_val_langs, test_langs = train_test_split(
|
| 575 |
+
questions, labels, languages,
|
| 576 |
+
test_size=args.test_size,
|
| 577 |
+
random_state=RANDOM_SEED,
|
| 578 |
+
stratify=labels
|
| 579 |
+
)
|
| 580 |
+
|
| 581 |
+
# Second split: separate validation from training
|
| 582 |
+
train_questions, val_questions, train_labels, val_labels, train_langs, val_langs = train_test_split(
|
| 583 |
+
train_val_questions, train_val_labels, train_val_langs,
|
| 584 |
+
test_size=args.val_size / (1 - args.test_size),
|
| 585 |
+
random_state=RANDOM_SEED,
|
| 586 |
+
stratify=train_val_labels
|
| 587 |
+
)
|
| 588 |
+
|
| 589 |
+
print(f"\nData splits:")
|
| 590 |
+
print(f" Training: {len(train_questions):4d} examples ({len(train_questions)/len(questions)*100:5.1f}%)")
|
| 591 |
+
print(f" Validation: {len(val_questions):4d} examples ({len(val_questions)/len(questions)*100:5.1f}%)")
|
| 592 |
+
print(f" Test: {len(test_questions):4d} examples ({len(test_questions)/len(questions)*100:5.1f}%)")
|
| 593 |
+
print(f" Total: {len(questions):4d} examples")
|
| 594 |
+
else:
|
| 595 |
+
# Load pre-split files
|
| 596 |
+
print("Loading pre-split datasets...\n")
|
| 597 |
+
train_questions, train_labels, train_langs = load_data(args.train)
|
| 598 |
+
val_questions, val_labels, val_langs = load_data(args.val)
|
| 599 |
+
test_questions, test_labels, test_langs = load_data(args.test)
|
| 600 |
+
|
| 601 |
+
# Print split summary
|
| 602 |
+
total_examples = len(train_questions) + len(val_questions) + len(test_questions)
|
| 603 |
+
print(f"\nData splits:")
|
| 604 |
+
print(f" Training: {len(train_questions):4d} examples ({len(train_questions)/total_examples*100:5.1f}%)")
|
| 605 |
+
print(f" Validation: {len(val_questions):4d} examples ({len(val_questions)/total_examples*100:5.1f}%)")
|
| 606 |
+
print(f" Test: {len(test_questions):4d} examples ({len(test_questions)/total_examples*100:5.1f}%)")
|
| 607 |
+
print(f" Total: {total_examples:4d} examples")
|
| 608 |
+
|
| 609 |
+
# Show class distribution per split
|
| 610 |
+
print("\nClass distribution per split:")
|
| 611 |
+
for split_name, split_labels in [('Train', train_labels), ('Val', val_labels), ('Test', test_labels)]:
|
| 612 |
+
counts = Counter(split_labels)
|
| 613 |
+
print(f"\n{split_name}:")
|
| 614 |
+
for label_id in sorted(counts.keys()):
|
| 615 |
+
cat_name = ID2LABEL[label_id]
|
| 616 |
+
print(f" {cat_name:20s}: {counts[label_id]:3d}")
|
| 617 |
+
|
| 618 |
+
# Load tokenizer and model
|
| 619 |
+
print(f"\nLoading tokenizer: {args.model_name}")
|
| 620 |
+
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
|
| 621 |
+
print("✓ Tokenizer loaded")
|
| 622 |
+
|
| 623 |
+
print(f"\nLoading model: {args.model_name}")
|
| 624 |
+
model = AutoModelForSequenceClassification.from_pretrained(
|
| 625 |
+
args.model_name,
|
| 626 |
+
num_labels=len(NFQA_CATEGORIES),
|
| 627 |
+
id2label=ID2LABEL,
|
| 628 |
+
label2id=LABEL2ID,
|
| 629 |
+
hidden_dropout_prob=args.dropout,
|
| 630 |
+
attention_probs_dropout_prob=args.dropout,
|
| 631 |
+
classifier_dropout=args.dropout
|
| 632 |
+
)
|
| 633 |
+
model.to(device)
|
| 634 |
+
|
| 635 |
+
print(f"✓ Model loaded")
|
| 636 |
+
print(f" Number of parameters: {sum(p.numel() for p in model.parameters()):,}")
|
| 637 |
+
print(f" Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
|
| 638 |
+
|
| 639 |
+
# Create datasets
|
| 640 |
+
print("\nCreating datasets...")
|
| 641 |
+
train_dataset = NFQADataset(train_questions, train_labels, tokenizer, args.max_length)
|
| 642 |
+
val_dataset = NFQADataset(val_questions, val_labels, tokenizer, args.max_length)
|
| 643 |
+
test_dataset = NFQADataset(test_questions, test_labels, tokenizer, args.max_length)
|
| 644 |
+
|
| 645 |
+
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
|
| 646 |
+
val_loader = DataLoader(val_dataset, batch_size=args.batch_size)
|
| 647 |
+
test_loader = DataLoader(test_dataset, batch_size=args.batch_size)
|
| 648 |
+
|
| 649 |
+
print(f"✓ Datasets created")
|
| 650 |
+
print(f" Train: {len(train_dataset)} examples ({len(train_loader)} batches)")
|
| 651 |
+
print(f" Val: {len(val_dataset)} examples ({len(val_loader)} batches)")
|
| 652 |
+
print(f" Test: {len(test_dataset)} examples ({len(test_loader)} batches)")
|
| 653 |
+
|
| 654 |
+
# Setup optimizer and scheduler
|
| 655 |
+
optimizer = AdamW(
|
| 656 |
+
model.parameters(),
|
| 657 |
+
lr=args.learning_rate,
|
| 658 |
+
weight_decay=args.weight_decay
|
| 659 |
+
)
|
| 660 |
+
|
| 661 |
+
total_steps = len(train_loader) * args.epochs
|
| 662 |
+
scheduler = get_linear_schedule_with_warmup(
|
| 663 |
+
optimizer,
|
| 664 |
+
num_warmup_steps=args.warmup_steps,
|
| 665 |
+
num_training_steps=total_steps
|
| 666 |
+
)
|
| 667 |
+
|
| 668 |
+
print(f"\n✓ Optimizer and scheduler configured")
|
| 669 |
+
print(f" Total training steps: {total_steps}")
|
| 670 |
+
print(f" Warmup steps: {args.warmup_steps}")
|
| 671 |
+
|
| 672 |
+
# Training loop
|
| 673 |
+
history = {
|
| 674 |
+
'train_loss': [],
|
| 675 |
+
'train_accuracy': [],
|
| 676 |
+
'val_loss': [],
|
| 677 |
+
'val_accuracy': [],
|
| 678 |
+
'val_f1': []
|
| 679 |
+
}
|
| 680 |
+
|
| 681 |
+
best_val_f1 = 0
|
| 682 |
+
best_epoch = 0
|
| 683 |
+
|
| 684 |
+
print("\n" + "="*80)
|
| 685 |
+
print("STARTING TRAINING")
|
| 686 |
+
print("="*80 + "\n")
|
| 687 |
+
|
| 688 |
+
for epoch in range(args.epochs):
|
| 689 |
+
print(f"\nEpoch {epoch + 1}/{args.epochs}")
|
| 690 |
+
print("-" * 80)
|
| 691 |
+
|
| 692 |
+
# Train
|
| 693 |
+
train_loss, train_acc = train_epoch(model, train_loader, optimizer, scheduler, device)
|
| 694 |
+
|
| 695 |
+
# Validate with detailed analysis
|
| 696 |
+
val_loss, val_acc, val_f1, val_preds, val_true = evaluate(
|
| 697 |
+
model, val_loader, device,
|
| 698 |
+
languages=val_langs,
|
| 699 |
+
desc="Validating",
|
| 700 |
+
show_analysis=True
|
| 701 |
+
)
|
| 702 |
+
|
| 703 |
+
# Update history
|
| 704 |
+
history['train_loss'].append(train_loss)
|
| 705 |
+
history['train_accuracy'].append(train_acc)
|
| 706 |
+
history['val_loss'].append(val_loss)
|
| 707 |
+
history['val_accuracy'].append(val_acc)
|
| 708 |
+
history['val_f1'].append(val_f1)
|
| 709 |
+
|
| 710 |
+
# Print metrics
|
| 711 |
+
print(f"\nEpoch {epoch + 1} Summary:")
|
| 712 |
+
print(f" Train Loss: {train_loss:.4f}")
|
| 713 |
+
print(f" Train Accuracy: {train_acc:.4f}")
|
| 714 |
+
print(f" Val Loss: {val_loss:.4f}")
|
| 715 |
+
print(f" Val Accuracy: {val_acc:.4f}")
|
| 716 |
+
print(f" Val F1 (Macro): {val_f1:.4f}")
|
| 717 |
+
|
| 718 |
+
# Save best model
|
| 719 |
+
if val_f1 > best_val_f1:
|
| 720 |
+
best_val_f1 = val_f1
|
| 721 |
+
best_epoch = epoch + 1
|
| 722 |
+
|
| 723 |
+
# Save model
|
| 724 |
+
model_path = os.path.join(args.output_dir, 'best_model')
|
| 725 |
+
model.save_pretrained(model_path)
|
| 726 |
+
tokenizer.save_pretrained(model_path)
|
| 727 |
+
|
| 728 |
+
print(f" ✓ New best model saved! (F1: {val_f1:.4f})")
|
| 729 |
+
|
| 730 |
+
print("\n" + "="*80)
|
| 731 |
+
print("TRAINING COMPLETE")
|
| 732 |
+
print("="*80)
|
| 733 |
+
print(f"Best epoch: {best_epoch}")
|
| 734 |
+
print(f"Best validation F1: {best_val_f1:.4f}")
|
| 735 |
+
print("="*80)
|
| 736 |
+
|
| 737 |
+
# Save training history
|
| 738 |
+
history_file = os.path.join(args.output_dir, 'training_history.json')
|
| 739 |
+
with open(history_file, 'w') as f:
|
| 740 |
+
json.dump(history, f, indent=2)
|
| 741 |
+
print(f"\n✓ Training history saved to: {history_file}")
|
| 742 |
+
|
| 743 |
+
# Save final model
|
| 744 |
+
final_model_path = os.path.join(args.output_dir, 'final_model')
|
| 745 |
+
model.save_pretrained(final_model_path)
|
| 746 |
+
tokenizer.save_pretrained(final_model_path)
|
| 747 |
+
print(f"✓ Final model saved to: {final_model_path}")
|
| 748 |
+
|
| 749 |
+
# Plot training curves
|
| 750 |
+
plot_training_curves(history, best_val_f1, args.output_dir)
|
| 751 |
+
|
| 752 |
+
# Load best model and evaluate on test set
|
| 753 |
+
print("\nLoading best model for final evaluation...")
|
| 754 |
+
best_model_path = os.path.join(args.output_dir, 'best_model')
|
| 755 |
+
model = AutoModelForSequenceClassification.from_pretrained(best_model_path)
|
| 756 |
+
model.to(device)
|
| 757 |
+
|
| 758 |
+
test_loss, test_acc, test_f1, test_preds, test_true = evaluate(model, test_loader, device, desc="Testing")
|
| 759 |
+
|
| 760 |
+
print("\n" + "="*80)
|
| 761 |
+
print("FINAL TEST SET RESULTS")
|
| 762 |
+
print("="*80)
|
| 763 |
+
print(f"Test Loss: {test_loss:.4f}")
|
| 764 |
+
print(f"Test Accuracy: {test_acc:.4f}")
|
| 765 |
+
print(f"Test F1 (Macro): {test_f1:.4f}")
|
| 766 |
+
print("="*80)
|
| 767 |
+
|
| 768 |
+
# Classification report
|
| 769 |
+
print("\n" + "="*80)
|
| 770 |
+
print("PER-CATEGORY PERFORMANCE")
|
| 771 |
+
print("="*80 + "\n")
|
| 772 |
+
|
| 773 |
+
report = classification_report(
|
| 774 |
+
test_true,
|
| 775 |
+
test_preds,
|
| 776 |
+
labels=list(range(len(NFQA_CATEGORIES))),
|
| 777 |
+
target_names=NFQA_CATEGORIES,
|
| 778 |
+
zero_division=0
|
| 779 |
+
)
|
| 780 |
+
print(report)
|
| 781 |
+
|
| 782 |
+
# Save report
|
| 783 |
+
report_file = os.path.join(args.output_dir, 'classification_report.txt')
|
| 784 |
+
with open(report_file, 'w') as f:
|
| 785 |
+
f.write(report)
|
| 786 |
+
print(f"✓ Classification report saved to: {report_file}")
|
| 787 |
+
|
| 788 |
+
# Plot confusion matrix
|
| 789 |
+
plot_confusion_matrix(test_true, test_preds, args.output_dir)
|
| 790 |
+
|
| 791 |
+
# Detailed performance analysis
|
| 792 |
+
print("\n" + "="*80)
|
| 793 |
+
print("DETAILED PERFORMANCE ANALYSIS")
|
| 794 |
+
print("="*80)
|
| 795 |
+
|
| 796 |
+
# Analyze by category
|
| 797 |
+
analyze_performance_by_category(test_preds, test_true)
|
| 798 |
+
|
| 799 |
+
# Analyze by language
|
| 800 |
+
analyze_performance_by_language(test_preds, test_true, test_langs, top_n=10)
|
| 801 |
+
|
| 802 |
+
# Analyze language-category combinations
|
| 803 |
+
analyze_language_category_combinations(test_preds, test_true, test_langs, top_n=15)
|
| 804 |
+
|
| 805 |
+
print("\n" + "="*80)
|
| 806 |
+
|
| 807 |
+
# Save test results
|
| 808 |
+
test_results = {
|
| 809 |
+
'test_loss': float(test_loss),
|
| 810 |
+
'test_accuracy': float(test_acc),
|
| 811 |
+
'test_f1_macro': float(test_f1),
|
| 812 |
+
'best_epoch': int(best_epoch),
|
| 813 |
+
'best_val_f1': float(best_val_f1),
|
| 814 |
+
'num_train_examples': len(train_questions),
|
| 815 |
+
'num_val_examples': len(val_questions),
|
| 816 |
+
'num_test_examples': len(test_questions),
|
| 817 |
+
'config': {
|
| 818 |
+
'model_name': args.model_name,
|
| 819 |
+
'max_length': args.max_length,
|
| 820 |
+
'batch_size': args.batch_size,
|
| 821 |
+
'learning_rate': args.learning_rate,
|
| 822 |
+
'num_epochs': args.epochs,
|
| 823 |
+
'warmup_steps': args.warmup_steps,
|
| 824 |
+
'weight_decay': args.weight_decay,
|
| 825 |
+
'dropout': args.dropout,
|
| 826 |
+
'data_source': 'pre-split' if has_split_inputs else 'single_file',
|
| 827 |
+
'train_file': args.train if has_split_inputs else args.input,
|
| 828 |
+
'val_file': args.val if has_split_inputs else None,
|
| 829 |
+
'test_file': args.test if has_split_inputs else None,
|
| 830 |
+
'auto_split': not has_split_inputs,
|
| 831 |
+
'test_size': args.test_size if not has_split_inputs else None,
|
| 832 |
+
'val_size': args.val_size if not has_split_inputs else None
|
| 833 |
+
},
|
| 834 |
+
'timestamp': datetime.now().isoformat()
|
| 835 |
+
}
|
| 836 |
+
|
| 837 |
+
results_file = os.path.join(args.output_dir, 'test_results.json')
|
| 838 |
+
with open(results_file, 'w') as f:
|
| 839 |
+
json.dump(test_results, f, indent=2)
|
| 840 |
+
print(f"✓ Test results saved to: {results_file}")
|
| 841 |
+
|
| 842 |
+
# Summary
|
| 843 |
+
print("\n" + "="*80)
|
| 844 |
+
print("TRAINING SUMMARY")
|
| 845 |
+
print("="*80)
|
| 846 |
+
print(f"\nModel: {args.model_name}")
|
| 847 |
+
print(f"Training examples: {len(train_questions)}")
|
| 848 |
+
print(f"Validation examples: {len(val_questions)}")
|
| 849 |
+
print(f"Test examples: {len(test_questions)}")
|
| 850 |
+
print(f"\nBest epoch: {best_epoch}/{args.epochs}")
|
| 851 |
+
print(f"Best validation F1: {best_val_f1:.4f}")
|
| 852 |
+
print(f"\nFinal test results:")
|
| 853 |
+
print(f" Accuracy: {test_acc:.4f}")
|
| 854 |
+
print(f" F1 Score (Macro): {test_f1:.4f}")
|
| 855 |
+
print(f"\nModel saved to: {args.output_dir}")
|
| 856 |
+
print(f"\nGenerated files:")
|
| 857 |
+
print(f" - best_model/ (best checkpoint)")
|
| 858 |
+
print(f" - final_model/ (last epoch)")
|
| 859 |
+
print(f" - training_history.json")
|
| 860 |
+
print(f" - training_curves.png")
|
| 861 |
+
print(f" - test_results.json")
|
| 862 |
+
print(f" - classification_report.txt")
|
| 863 |
+
print(f" - confusion_matrix.png")
|
| 864 |
+
print("\n" + "="*80)
|
| 865 |
+
print("✅ Training complete! Model ready for deployment.")
|
| 866 |
+
print("="*80)
|
| 867 |
+
|
| 868 |
+
|
| 869 |
+
if __name__ == '__main__':
|
| 870 |
+
main()
|
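
A note on the two-stage split in `main()`: `--test-size` and `--val-size` are both fractions of the *whole* dataset, which is why the second `train_test_split` rescales to `val_size / (1 - test_size)`. A minimal standalone sketch of that arithmetic (synthetic data, not part of the committed script):

```python
from sklearn.model_selection import train_test_split

RANDOM_SEED = 42
test_size, val_size = 0.2, 0.1  # fractions of the whole dataset

# Synthetic stand-in data: 100 questions, two balanced classes.
questions = [f"q{i}" for i in range(100)]
labels = [i % 2 for i in range(100)]

# First split: carve out the test set (20% of the whole).
train_val_q, test_q, train_val_y, test_y = train_test_split(
    questions, labels,
    test_size=test_size,
    random_state=RANDOM_SEED,
    stratify=labels,
)

# Second split: val_size is relative to the whole dataset, so it is
# rescaled against the remaining (1 - test_size) portion: 0.1 / 0.8 = 0.125.
train_q, val_q, train_y, val_y = train_test_split(
    train_val_q, train_val_y,
    test_size=val_size / (1 - test_size),
    random_state=RANDOM_SEED,
    stratify=train_val_y,
)

print(len(train_q), len(val_q), len(test_q))  # 70 10 20
```

This yields the intended 70/10/20 split while keeping both splits stratified by label.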