Text Ranking
sentence-transformers
Safetensors
cross-encoder
reranker
Generated from Trainer
dataset_size:3190
loss:ListNetLoss
custom_code
Instructions to use Pranjal2002/jina_finance_v2 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- sentence-transformers
How to use Pranjal2002/jina_finance_v2 with sentence-transformers:
from sentence_transformers import CrossEncoder

model = CrossEncoder("Pranjal2002/jina_finance_v2", trust_remote_code=True)

query = "Which planet is known as the Red Planet?"
passages = [
    "Venus is often called Earth's twin because of its similar size and proximity.",
    "Mars, known for its reddish appearance, is often referred to as the Red Planet.",
    "Jupiter, the largest planet in our solar system, has a prominent red spot.",
    "Saturn, famous for its rings, is sometimes mistaken for the Red Planet.",
]

scores = model.predict([(query, passage) for passage in passages])
print(scores)

- Notebooks
- Google Colab
- Kaggle
Add new CrossEncoder model
- .gitattributes +1 -0
- README.md +402 -0
- block.py +470 -0
- config.json +51 -0
- configuration_xlm_roberta.py +69 -0
- embedding.py +62 -0
- mha.py +662 -0
- mlp.py +194 -0
- model.safetensors +3 -0
- modeling_xlm_roberta.py +1119 -0
- special_tokens_map.json +51 -0
- tokenizer.json +3 -0
- tokenizer_config.json +55 -0
- xlm_padding.py +218 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md
ADDED
@@ -0,0 +1,402 @@
---
tags:
- sentence-transformers
- cross-encoder
- reranker
- generated_from_trainer
- dataset_size:3190
- loss:ListNetLoss
base_model: jinaai/jina-reranker-v2-base-multilingual
pipeline_tag: text-ranking
library_name: sentence-transformers
---

# CrossEncoder based on jinaai/jina-reranker-v2-base-multilingual

This is a [Cross Encoder](https://www.sbert.net/docs/cross_encoder/usage/usage.html) model finetuned from [jinaai/jina-reranker-v2-base-multilingual](https://huggingface.co/jinaai/jina-reranker-v2-base-multilingual) using the [sentence-transformers](https://www.SBERT.net) library. It computes scores for pairs of texts, which can be used for text reranking and semantic search.

## Model Details

### Model Description
- **Model Type:** Cross Encoder
- **Base model:** [jinaai/jina-reranker-v2-base-multilingual](https://huggingface.co/jinaai/jina-reranker-v2-base-multilingual) <!-- at revision 2f894e63642a95228da19cdd583cd2309983c867 -->
- **Maximum Sequence Length:** 1024 tokens
- **Number of Output Labels:** 1 label
<!-- - **Training Dataset:** Unknown -->
<!-- - **Language:** Unknown -->
<!-- - **License:** Unknown -->

### Model Sources

- **Documentation:** [Sentence Transformers Documentation](https://sbert.net)
- **Documentation:** [Cross Encoder Documentation](https://www.sbert.net/docs/cross_encoder/usage/usage.html)
- **Repository:** [Sentence Transformers on GitHub](https://github.com/UKPLab/sentence-transformers)
- **Hugging Face:** [Cross Encoders on Hugging Face](https://huggingface.co/models?library=sentence-transformers&other=cross-encoder)

## Usage

### Direct Usage (Sentence Transformers)

First install the Sentence Transformers library:

```bash
pip install -U sentence-transformers
```

Then you can load this model and run inference.
```python
from sentence_transformers import CrossEncoder

# Download from the 🤗 Hub
# The repository ships custom modeling code (see modeling_xlm_roberta.py),
# so trust_remote_code=True is passed, as in the Hub-generated snippet.
model = CrossEncoder("Pranjal2002/jina_finance_v2", trust_remote_code=True)
# Get scores for pairs of texts
pairs = [
    ['What consolidation trends among competitors are highlighted in disclosures affecting Regions Financial Corporation’s regional banking operations?', '10-K'],
    ['What consolidation trends among competitors are highlighted in disclosures affecting Regions Financial Corporation’s regional banking operations?', 'Earnings'],
    ['What consolidation trends among competitors are highlighted in disclosures affecting Regions Financial Corporation’s regional banking operations?', 'DEF14A'],
    ['What consolidation trends among competitors are highlighted in disclosures affecting Regions Financial Corporation’s regional banking operations?', '8-K'],
    ['What consolidation trends among competitors are highlighted in disclosures affecting Regions Financial Corporation’s regional banking operations?', '10-Q'],
]
scores = model.predict(pairs)
print(scores.shape)
# (5,)

# Or rank different texts based on similarity to a single text
ranks = model.rank(
    'What consolidation trends among competitors are highlighted in disclosures affecting Regions Financial Corporation’s regional banking operations?',
    [
        '10-K',
        'Earnings',
        'DEF14A',
        '8-K',
        '10-Q',
    ]
)
# [{'corpus_id': ..., 'score': ...}, {'corpus_id': ..., 'score': ...}, ...]
```
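
The `rank` call above returns a list of dictionaries with `corpus_id` and `score` keys, as its trailing comment indicates. A minimal sketch of mapping that output back to the filing-type labels could look like the following. The helper name `order_labels` and the sample scores are hypothetical, made up for illustration rather than produced by the model.

```python
from typing import Dict, List, Union

RankEntry = Dict[str, Union[int, float]]


def order_labels(ranks: List[RankEntry], labels: List[str]) -> List[str]:
    """Return labels ordered from highest to lowest score.

    `ranks` is assumed to have the shape shown in the usage comment:
    one dict per candidate with `corpus_id` (index into `labels`) and `score`.
    """
    # Sort defensively by score, in case the entries are not already ordered.
    ordered = sorted(ranks, key=lambda entry: entry["score"], reverse=True)
    return [labels[int(entry["corpus_id"])] for entry in ordered]


doc_types = ["10-K", "Earnings", "DEF14A", "8-K", "10-Q"]
# Hypothetical scores, for illustration only (not real model output).
example_ranks = [
    {"corpus_id": 2, "score": 0.1},
    {"corpus_id": 0, "score": 0.9},
    {"corpus_id": 4, "score": 0.4},
    {"corpus_id": 1, "score": 0.7},
    {"corpus_id": 3, "score": 0.2},
]
print(order_labels(example_ranks, doc_types))
```

Keeping the label list separate from the scores means the same helper works for any candidate set, as long as `corpus_id` indexes into the list passed to `rank`.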

<!--
### Direct Usage (Transformers)

<details><summary>Click to see the direct usage in Transformers</summary>

</details>
-->

<!--
### Downstream Usage (Sentence Transformers)

You can finetune this model on your own dataset.

<details><summary>Click to expand</summary>

</details>
-->

<!--
### Out-of-Scope Use

*List how the model may foreseeably be misused and address what users ought not to do with the model.*
-->

<!--
## Bias, Risks and Limitations

*What are the known or foreseeable issues stemming from this model? You could also flag here known failure cases or weaknesses of the model.*
-->

<!--
### Recommendations

*What are recommendations with respect to the foreseeable issues? For example, filtering explicit content.*
-->

## Training Details

### Training Dataset

#### Unnamed Dataset

* Size: 3,190 training samples
* Columns: <code>query</code>, <code>docs</code>, and <code>labels</code>
* Approximate statistics based on the first 1000 samples:
  | | query | docs | labels |
  |:--------|:-------------------------------------------------------------------------------------------------|:-----------------------------------|:-----------------------------------|
  | type | string | list | list |
  | details | <ul><li>min: 55 characters</li><li>mean: 103.12 characters</li><li>max: 180 characters</li></ul> | <ul><li>size: 5 elements</li></ul> | <ul><li>size: 5 elements</li></ul> |
* Samples:
  | query | docs | labels |
  |:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:-----------------------------------------------------------|:-----------------------------|
  | <code>What year over year growth rate was shown for paid memberships in the same table</code> | <code>['10-Q', '10-K', '8-K', 'Earnings', 'DEF14A']</code> | <code>[4, 3, 2, 1, 0]</code> |
  | <code>How did non‑GAAP EPS growth align with the incentive metrics set for management?</code> | <code>['DEF14A', '8-K', '10-K', '10-Q', 'Earnings']</code> | <code>[2, 1, 0, 0, 0]</code> |
  | <code>What questions were raised regarding Xcel Energy Inc.’s risk factors and mitigation plans related to the integration of renewable energy sources into their grid?</code> | <code>['10-K', 'Earnings', '8-K', '10-Q', 'DEF14A']</code> | <code>[4, 3, 2, 1, 0]</code> |
* Loss: [<code>ListNetLoss</code>](https://sbert.net/docs/package_reference/cross_encoder/losses.html#listnetloss) with these parameters:
  ```json
  {
      "activation_fn": "torch.nn.modules.linear.Identity",
      "mini_batch_size": null
  }
  ```
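
ListNet (Cao et al., 2007, cited in the Citation section) compares two top-one probability distributions for each query: a softmax over the relevance labels and a softmax over the predicted scores, combined through cross-entropy. The following is a plain-Python sketch of that idea for a single query, assuming the identity activation listed in the parameters above. It is an illustration, not the library's implementation, and the function names are hypothetical.

```python
import math
from typing import List


def softmax(values: List[float]) -> List[float]:
    """Numerically stable softmax over a list of floats."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def listnet_loss(scores: List[float], labels: List[float]) -> float:
    """Cross-entropy between softmax(labels) and softmax(scores) for one query."""
    target = softmax(labels)
    predicted = softmax(scores)
    return -sum(t * math.log(p) for t, p in zip(target, predicted))


labels = [4.0, 3.0, 2.0, 1.0, 0.0]  # graded relevance, as in the samples table
uniform_scores = [0.0] * 5          # a model that cannot tell the documents apart
matching_scores = list(labels)      # scores that reproduce the label distribution

print(listnet_loss(uniform_scores, labels))
print(listnet_loss(matching_scores, labels))
```

With uniform scores over five documents the loss equals ln 5 (about 1.61) regardless of the labels, which gives one reference point when reading the loss values in the training logs.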

### Evaluation Dataset

#### Unnamed Dataset

* Size: 798 evaluation samples
* Columns: <code>query</code>, <code>docs</code>, and <code>labels</code>
* Approximate statistics based on the first 798 samples:
  | | query | docs | labels |
  |:--------|:-------------------------------------------------------------------------------------------------|:-----------------------------------|:-----------------------------------|
  | type | string | list | list |
  | details | <ul><li>min: 53 characters</li><li>mean: 102.91 characters</li><li>max: 179 characters</li></ul> | <ul><li>size: 5 elements</li></ul> | <ul><li>size: 5 elements</li></ul> |
* Samples:
  | query | docs | labels |
  |:---------------------------------------------------------------------------------------------------------------------------------------------------------------|:-----------------------------------------------------------|:-----------------------------|
  | <code>What consolidation trends among competitors are highlighted in disclosures affecting Regions Financial Corporation’s regional banking operations?</code> | <code>['10-K', 'Earnings', 'DEF14A', '8-K', '10-Q']</code> | <code>[4, 3, 2, 1, 0]</code> |
  | <code>How does Pentair manage equity award burn rate or share pool availability?</code> | <code>['10-K', 'DEF14A', '10-Q', 'Earnings', '8-K']</code> | <code>[4, 3, 2, 1, 0]</code> |
  | <code>What key takeaways emerged from Valero Energy Corporation’s most recent earnings announcement?</code> | <code>['10-Q', '10-K', 'Earnings', '8-K', 'DEF14A']</code> | <code>[4, 3, 2, 1, 0]</code> |
* Loss: [<code>ListNetLoss</code>](https://sbert.net/docs/package_reference/cross_encoder/losses.html#listnetloss) with these parameters:
  ```json
  {
      "activation_fn": "torch.nn.modules.linear.Identity",
      "mini_batch_size": null
  }
  ```

### Training Hyperparameters
#### Non-Default Hyperparameters

- `eval_strategy`: steps
- `per_device_train_batch_size`: 4
- `per_device_eval_batch_size`: 4
- `gradient_accumulation_steps`: 2
- `learning_rate`: 2e-05
- `num_train_epochs`: 5
- `warmup_steps`: 100
- `bf16`: True
- `load_best_model_at_end`: True
- `optim`: adamw_torch
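
To make the interaction of `learning_rate`, `warmup_steps`, and the linear scheduler (`lr_scheduler_type`: linear in the full hyperparameter list) concrete, here is a small sketch of a linear warmup followed by linear decay to zero. The function name is hypothetical, and the total of 1995 optimizer steps is an assumption (5 epochs of 399 steps on a single device), not a value stated in this card.

```python
def linear_schedule_lr(
    step: int,
    base_lr: float = 2e-5,     # `learning_rate` from the list above
    warmup_steps: int = 100,   # `warmup_steps` from the list above
    total_steps: int = 1995,   # assumed: 5 epochs x 399 optimizer steps
) -> float:
    """Learning rate after `step` optimizer updates: linear warmup, then linear decay."""
    if step < warmup_steps:
        # Ramp up from 0 to base_lr over the warmup period.
        return base_lr * step / max(1, warmup_steps)
    # Decay linearly from base_lr at the end of warmup to 0 at total_steps.
    remaining = max(0, total_steps - step)
    return base_lr * remaining / max(1, total_steps - warmup_steps)


for step in (0, 50, 100, 1000, 1995):
    print(step, linear_schedule_lr(step))
```

Under these assumptions the peak learning rate is reached at step 100, which is before the first evaluation at step 200 in the training logs.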

#### All Hyperparameters
<details><summary>Click to expand</summary>

- `overwrite_output_dir`: False
- `do_predict`: False
- `eval_strategy`: steps
- `prediction_loss_only`: True
- `per_device_train_batch_size`: 4
- `per_device_eval_batch_size`: 4
- `per_gpu_train_batch_size`: None
- `per_gpu_eval_batch_size`: None
- `gradient_accumulation_steps`: 2
- `eval_accumulation_steps`: None
- `torch_empty_cache_steps`: None
- `learning_rate`: 2e-05
- `weight_decay`: 0.0
- `adam_beta1`: 0.9
- `adam_beta2`: 0.999
- `adam_epsilon`: 1e-08
- `max_grad_norm`: 1.0
- `num_train_epochs`: 5
- `max_steps`: -1
- `lr_scheduler_type`: linear
- `lr_scheduler_kwargs`: {}
- `warmup_ratio`: 0.0
- `warmup_steps`: 100
- `log_level`: passive
- `log_level_replica`: warning
- `log_on_each_node`: True
- `logging_nan_inf_filter`: True
- `save_safetensors`: True
- `save_on_each_node`: False
- `save_only_model`: False
- `restore_callback_states_from_checkpoint`: False
- `no_cuda`: False
- `use_cpu`: False
- `use_mps_device`: False
- `seed`: 42
- `data_seed`: None
- `jit_mode_eval`: False
- `use_ipex`: False
- `bf16`: True
- `fp16`: False
- `fp16_opt_level`: O1
- `half_precision_backend`: auto
- `bf16_full_eval`: False
- `fp16_full_eval`: False
- `tf32`: None
- `local_rank`: 0
- `ddp_backend`: None
- `tpu_num_cores`: None
- `tpu_metrics_debug`: False
- `debug`: []
- `dataloader_drop_last`: False
- `dataloader_num_workers`: 0
- `dataloader_prefetch_factor`: None
- `past_index`: -1
- `disable_tqdm`: False
- `remove_unused_columns`: True
- `label_names`: None
- `load_best_model_at_end`: True
- `ignore_data_skip`: False
- `fsdp`: []
- `fsdp_min_num_params`: 0
- `fsdp_config`: {'min_num_params': 0, 'xla': False, 'xla_fsdp_v2': False, 'xla_fsdp_grad_ckpt': False}
- `fsdp_transformer_layer_cls_to_wrap`: None
- `accelerator_config`: {'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None}
- `parallelism_config`: None
- `deepspeed`: None
- `label_smoothing_factor`: 0.0
- `optim`: adamw_torch
- `optim_args`: None
- `adafactor`: False
- `group_by_length`: False
- `length_column_name`: length
- `ddp_find_unused_parameters`: None
- `ddp_bucket_cap_mb`: None
- `ddp_broadcast_buffers`: False
- `dataloader_pin_memory`: True
- `dataloader_persistent_workers`: False
- `skip_memory_metrics`: True
- `use_legacy_prediction_loop`: False
- `push_to_hub`: False
- `resume_from_checkpoint`: None
- `hub_model_id`: None
- `hub_strategy`: every_save
- `hub_private_repo`: None
- `hub_always_push`: False
- `hub_revision`: None
- `gradient_checkpointing`: False
- `gradient_checkpointing_kwargs`: None
- `include_inputs_for_metrics`: False
- `include_for_metrics`: []
- `eval_do_concat_batches`: True
- `fp16_backend`: auto
- `push_to_hub_model_id`: None
- `push_to_hub_organization`: None
- `mp_parameters`:
- `auto_find_batch_size`: False
- `full_determinism`: False
- `torchdynamo`: None
- `ray_scope`: last
- `ddp_timeout`: 1800
- `torch_compile`: False
- `torch_compile_backend`: None
- `torch_compile_mode`: None
- `include_tokens_per_second`: False
- `include_num_input_tokens_seen`: False
- `neftune_noise_alpha`: None
- `optim_target_modules`: None
- `batch_eval_metrics`: False
- `eval_on_start`: False
- `use_liger_kernel`: False
- `liger_kernel_config`: None
- `eval_use_gather_object`: False
- `average_tokens_across_devices`: False
- `prompts`: None
- `batch_sampler`: batch_sampler
- `multi_dataset_batch_sampler`: proportional
- `router_mapping`: {}
- `learning_rate_mapping`: {}

</details>

### Training Logs
| Epoch | Step | Training Loss | Validation Loss |
|:--------:|:--------:|:-------------:|:---------------:|
| 0.1253 | 50 | 1.7131 | - |
| 0.2506 | 100 | 1.5888 | - |
| 0.3759 | 150 | 1.49 | - |
| 0.5013 | 200 | 1.4408 | 1.4397 |
| 0.6266 | 250 | 1.4225 | - |
| 0.7519 | 300 | 1.4216 | - |
| 0.8772 | 350 | 1.4329 | - |
| 1.0025 | 400 | 1.3996 | 1.4083 |
| 1.1278 | 450 | 1.4126 | - |
| 1.2531 | 500 | 1.4002 | - |
| 1.3784 | 550 | 1.4098 | - |
| 1.5038 | 600 | 1.3692 | 1.4042 |
| 1.6291 | 650 | 1.3784 | - |
| 1.7544 | 700 | 1.4014 | - |
| 1.8797 | 750 | 1.3815 | - |
| 2.0050 | 800 | 1.3982 | 1.3910 |
| 2.1303 | 850 | 1.3864 | - |
| 2.2556 | 900 | 1.3983 | - |
| 2.3810 | 950 | 1.3662 | - |
| 2.5063 | 1000 | 1.3747 | 1.3968 |
| 2.6316 | 1050 | 1.3739 | - |
| 2.7569 | 1100 | 1.3687 | - |
| 2.8822 | 1150 | 1.3858 | - |
| 3.0075 | 1200 | 1.3847 | 1.3897 |
| 3.1328 | 1250 | 1.3684 | - |
| 3.2581 | 1300 | 1.3787 | - |
| 3.3835 | 1350 | 1.3612 | - |
| 3.5088 | 1400 | 1.3906 | 1.3920 |
| 3.6341 | 1450 | 1.3838 | - |
| 3.7594 | 1500 | 1.3817 | - |
| 3.8847 | 1550 | 1.3615 | - |
| **4.01** | **1600** | **1.3978** | **1.3892** |
| 4.1353 | 1650 | 1.3793 | - |
| 4.2607 | 1700 | 1.3753 | - |
| 4.3860 | 1750 | 1.3847 | - |
| 4.5113 | 1800 | 1.3857 | 1.3887 |
| 4.6366 | 1850 | 1.3583 | - |
| 4.7619 | 1900 | 1.3644 | - |
| 4.8872 | 1950 | 1.3696 | - |

* The bold row denotes the saved checkpoint.
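
The Epoch column can be reproduced from the dataset size and batch settings. The sketch below is a worked arithmetic example, assuming a single training device (the card does not state the device count); the function name is hypothetical.

```python
import math


def optimizer_steps_per_epoch(
    num_samples: int,
    per_device_batch_size: int,
    grad_accum_steps: int,
    num_devices: int = 1,  # assumption: single device
) -> int:
    """Number of optimizer updates needed to see every training sample once."""
    batches = math.ceil(num_samples / (per_device_batch_size * num_devices))
    return math.ceil(batches / grad_accum_steps)


# 3,190 samples, per-device batch size 4, gradient accumulation 2.
steps = optimizer_steps_per_epoch(3190, 4, 2)
total_steps = steps * 5  # num_train_epochs = 5

print(steps, total_steps)
# Fractional epoch reached at a few logged steps.
print(round(50 / steps, 4), round(400 / steps, 4), round(1600 / steps, 4))
```

Under the single-device assumption, each logged step count divided by the steps per epoch matches the Epoch column of the table above to four decimal places, which suggests the effective batch size was 8 samples per update.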

### Framework Versions
- Python: 3.12.11
- Sentence Transformers: 5.1.0
- Transformers: 4.56.1
- PyTorch: 2.8.0+cu126
- Accelerate: 1.10.1
- Datasets: 4.0.0
- Tokenizers: 0.22.0

## Citation

### BibTeX

#### Sentence Transformers
```bibtex
@inproceedings{reimers-2019-sentence-bert,
    title = "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks",
    author = "Reimers, Nils and Gurevych, Iryna",
    booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing",
    month = "11",
    year = "2019",
    publisher = "Association for Computational Linguistics",
    url = "https://arxiv.org/abs/1908.10084",
}
```

#### ListNetLoss
```bibtex
@inproceedings{cao2007learning,
    title={Learning to Rank: From Pairwise Approach to Listwise Approach},
    author={Cao, Zhe and Qin, Tao and Liu, Tie-Yan and Tsai, Ming-Feng and Li, Hang},
    booktitle={Proceedings of the 24th international conference on Machine learning},
    pages={129--136},
    year={2007}
}
```

<!--
## Glossary

*Clearly define terms in order to be accessible across audiences.*
-->

<!--
## Model Card Authors

*Lists the people who create the model card, providing recognition and accountability for the detailed work that goes into its construction.*
-->

<!--
## Model Card Contact

*Provides a way for people who have updates to the Model Card, suggestions, or questions, to contact the Model Card authors.*
-->
block.py
ADDED
@@ -0,0 +1,470 @@
| 1 |
+
# This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/block.py
|
| 2 |
+
# Commit id: abbc1311731867310635f9edc2a9ec18317c8c48
|
| 3 |
+
|
| 4 |
+
# Copyright (c) 2024, Tri Dao.
|
| 5 |
+
|
| 6 |
+
from functools import partial
|
| 7 |
+
from typing import Optional
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.fx
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
from torch import Tensor
|
| 14 |
+
|
| 15 |
+
from .mha import MHA
|
| 16 |
+
from .mlp import Mlp
|
| 17 |
+
|
| 18 |
+
try:
|
| 19 |
+
from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
|
| 20 |
+
except ImportError:
|
| 21 |
+
layer_norm_fn, RMSNorm = None, None
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def stochastic_depth(
|
| 25 |
+
input: Tensor, p: float, mode: str, training: bool = True
|
| 26 |
+
) -> Tensor:
|
| 27 |
+
"""
|
| 28 |
+
Implements the Stochastic Depth from `"Deep Networks with Stochastic Depth"
|
| 29 |
+
<https://arxiv.org/abs/1603.09382>`_ used for randomly dropping residual
|
| 30 |
+
branches of residual architectures.
|
| 31 |
+
Args:
|
| 32 |
+
input (Tensor[N, ...]): The input tensor or arbitrary dimensions with the first one
|
| 33 |
+
being its batch i.e. a batch with ``N`` rows.
|
| 34 |
+
p (float): probability of the input to be zeroed.
|
| 35 |
+
mode (str): ``"batch"`` or ``"row"``.
|
| 36 |
+
``"batch"`` randomly zeroes the entire input, ``"row"`` zeroes
|
| 37 |
+
randomly selected rows from the batch.
|
| 38 |
+
training: apply stochastic depth if is ``True``. Default: ``True``
|
| 39 |
+
Returns:
|
| 40 |
+
Tensor[N, ...]: The randomly zeroed tensor.
|
| 41 |
+
"""
|
| 42 |
+
if p < 0.0 or p > 1.0:
|
| 43 |
+
raise ValueError(f"drop probability has to be between 0 and 1, but got {p}")
|
| 44 |
+
if mode not in ["batch", "row"]:
|
| 45 |
+
raise ValueError(f"mode has to be either 'batch' or 'row', but got {mode}")
|
| 46 |
+
if not training or p == 0.0:
|
| 47 |
+
return input
|
| 48 |
+
|
| 49 |
+
survival_rate = 1.0 - p
|
| 50 |
+
if mode == "row":
|
| 51 |
+
size = [input.shape[0]] + [1] * (input.ndim - 1)
|
| 52 |
+
else:
|
| 53 |
+
size = [1] * input.ndim
|
| 54 |
+
noise = torch.empty(size, dtype=input.dtype, device=input.device)
|
| 55 |
+
noise = noise.bernoulli_(survival_rate)
|
| 56 |
+
if survival_rate > 0.0:
|
| 57 |
+
noise.div_(survival_rate)
|
| 58 |
+
return input * noise
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
torch.fx.wrap("stochastic_depth")
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class StochasticDepth(nn.Module):
|
| 65 |
+
"""
|
| 66 |
+
See :func:`stochastic_depth`.
|
| 67 |
+
"""
|
| 68 |
+
|
| 69 |
+
def __init__(self, p: float, mode: str) -> None:
|
| 70 |
+
super().__init__()
|
| 71 |
+
self.p = p
|
| 72 |
+
self.mode = mode
|
| 73 |
+
|
| 74 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 75 |
+
return stochastic_depth(input, self.p, self.mode, self.training)
|
| 76 |
+
|
| 77 |
+
def __repr__(self) -> str:
|
| 78 |
+
s = f"{self.__class__.__name__}(p={self.p}, mode={self.mode})"
|
| 79 |
+
return s
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class Block(nn.Module):
|
| 83 |
+
def __init__(
|
| 84 |
+
self,
|
| 85 |
+
dim,
|
| 86 |
+
mixer_cls=None,
|
| 87 |
+
mlp_cls=None,
|
| 88 |
+
norm_cls=nn.LayerNorm,
|
| 89 |
+
dropout_cls=nn.Dropout,
|
| 90 |
+
prenorm=True,
|
| 91 |
+
resid_dropout1=0.0,
|
| 92 |
+
resid_dropout2=0.0,
|
| 93 |
+
drop_path1=0.0,
|
| 94 |
+
drop_path2=0.0,
|
| 95 |
+
fused_dropout_add_ln=False,
|
| 96 |
+
return_residual=False,
|
| 97 |
+
residual_in_fp32=False,
|
| 98 |
+
sequence_parallel=False,
|
| 99 |
+
mark_shared_params=False,
|
| 100 |
+
):
|
| 101 |
+
"""
|
| 102 |
+
For prenorm=True, this Block has a slightly different structure compared to a regular
|
| 103 |
+
prenorm Transformer block.
|
| 104 |
+
The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add.
|
| 105 |
+
[Ref: https://arxiv.org/abs/2002.04745]
|
| 106 |
+
Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both
|
| 107 |
+
the hidden_states (output of the MLP) and the residual.
|
| 108 |
+
This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
|
| 109 |
+
The residual needs to be provided (except for the very first block).
|
| 110 |
+
|
| 111 |
+
For prenorm=False, this Block has the same structure as a regular postnorm Transformer
|
| 112 |
+
block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN.
|
| 113 |
+
|
| 114 |
+
return_residual: whether each of the sub-layers (mixer and mlp) will return the residual.
|
| 115 |
+
This is for performance reason: for post-norm architecture, returning the input allows us
|
| 116 |
+
to fuse the backward of nn.Linear with the residual connection.
|
| 117 |
+
"""
        super().__init__()
        self.prenorm = prenorm
        self.fused_dropout_add_ln = fused_dropout_add_ln
        self.return_residual = return_residual
        self.residual_in_fp32 = residual_in_fp32
        if self.residual_in_fp32:
            assert self.prenorm, "residual_in_fp32 is only compatible with prenorm=True"
        if mixer_cls is None:
            mixer_cls = partial(MHA, num_heads=dim // 64)
        if mlp_cls is None:
            mlp_cls = partial(Mlp, hidden_features=4 * dim)
        self.mixer = mixer_cls(dim)
        self.dropout1 = dropout_cls(resid_dropout1)
        self.drop_path1 = StochasticDepth(drop_path1, mode="row")
        self.norm1 = norm_cls(dim)
        self.mlp = mlp_cls(dim)
        if not isinstance(self.mlp, nn.Identity):
            self.dropout2 = dropout_cls(resid_dropout2)
            self.drop_path2 = StochasticDepth(drop_path2, mode="row")
            self.norm2 = norm_cls(dim)

        if self.fused_dropout_add_ln:
            assert layer_norm_fn is not None, "Triton is not installed"
            assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
                self.dropout1, nn.Dropout
            )

        # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
        # then the input to each worker in the tensor parallel group will be different.
        # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
        # For now this is not an issue because we always use sequence_parallel=True during training
        # and only use sequence_parallel=False during inference.

        # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
        if sequence_parallel:
            for p in self.norm1.parameters():
                p._sequence_parallel = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._sequence_parallel = True
        # Mark the norm parameters as "shared_params" so that we sync their values at init.
        if mark_shared_params:
            for p in self.norm1.parameters():
                p._shared_params = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._shared_params = True

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.mixer.allocate_inference_cache(
            batch_size, max_seqlen, dtype=dtype, **kwargs
        )

    def forward(
        self,
        hidden_states: Tensor,
        residual: Optional[Tensor] = None,
        mixer_subset=None,
        mixer_kwargs=None,
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: if postnorm, residual=None; if prenorm, hidden_states = Attn/MLP(LN(residual))
            mixer_subset: for cross-attention only. If not None, will take a subset of x
                before applying the query projection. Useful for e.g., ViT where we only care
                about the CLS token in the last layer.
        """
        if self.prenorm:
            if not self.fused_dropout_add_ln:
                dropped = self.drop_path1(self.dropout1(hidden_states))
                residual = (dropped + residual) if residual is not None else dropped
                hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
                if self.residual_in_fp32:
                    residual = residual.to(torch.float32)
            else:
                if self.drop_path1.p == 0 or not self.training:
                    rowscale1 = None
                else:
                    rowscale1 = self.drop_path1(
                        torch.ones(
                            hidden_states.shape[:-1],
                            device=hidden_states.device,
                            dtype=hidden_states.dtype,
                        )
                    )
                hidden_states, residual = layer_norm_fn(
                    hidden_states,
                    self.norm1.weight,
                    self.norm1.bias,
                    residual=residual,
                    eps=self.norm1.eps,
                    dropout_p=self.dropout1.p if self.training else 0.0,
                    rowscale=rowscale1,
                    prenorm=True,
                    residual_in_fp32=self.residual_in_fp32,
                    is_rms_norm=isinstance(self.norm1, RMSNorm),
                )
            if mixer_kwargs is None:
                mixer_kwargs = {}
            if mixer_subset is not None:
                mixer_kwargs["mixer_subset"] = mixer_subset
            hidden_states = self.mixer(hidden_states, **mixer_kwargs)
            if mixer_subset is not None:
                residual = residual[:, mixer_subset]
            if not isinstance(self.mlp, nn.Identity):
                if not self.fused_dropout_add_ln:
                    dropped = self.drop_path2(self.dropout2(hidden_states))
                    residual = (dropped + residual) if residual is not None else dropped
                    hidden_states = self.norm2(
                        residual.to(dtype=self.norm2.weight.dtype)
                    )
                    if self.residual_in_fp32:
                        residual = residual.to(torch.float32)
                else:
                    if self.drop_path2.p == 0 or not self.training:
                        rowscale2 = None
                    else:
                        rowscale2 = self.drop_path2(
                            torch.ones(
                                hidden_states.shape[:-1],
                                device=hidden_states.device,
                                dtype=hidden_states.dtype,
                            )
                        )
                    hidden_states, residual = layer_norm_fn(
                        hidden_states,
                        self.norm2.weight,
                        self.norm2.bias,
                        residual=residual,
                        eps=self.norm2.eps,
                        dropout_p=self.dropout2.p if self.training else 0.0,
                        rowscale=rowscale2,
                        prenorm=True,
                        residual_in_fp32=self.residual_in_fp32,
                        is_rms_norm=isinstance(self.norm2, RMSNorm),
                    )
                hidden_states = self.mlp(hidden_states)
            return hidden_states, residual
        else:
            assert residual is None
            mixer_out = self.mixer(
                hidden_states, **(mixer_kwargs if mixer_kwargs is not None else {})
            )
            if self.return_residual:  # mixer out is actually a pair here
                mixer_out, hidden_states = mixer_out
            if not self.fused_dropout_add_ln:
                hidden_states = self.norm1(
                    (self.drop_path1(self.dropout1(mixer_out)) + hidden_states).to(
                        dtype=self.norm1.weight.dtype
                    )
                )
            else:
                if self.drop_path1.p == 0 or not self.training:
                    rowscale1 = None
                else:
                    rowscale1 = self.drop_path1(
                        torch.ones(
                            mixer_out.shape[:-1],
                            device=mixer_out.device,
                            dtype=mixer_out.dtype,
                        )
                    )
                hidden_states = layer_norm_fn(
                    mixer_out,
                    self.norm1.weight,
                    self.norm1.bias,
                    residual=hidden_states,
                    eps=self.norm1.eps,
                    dropout_p=self.dropout1.p if self.training else 0.0,
                    rowscale=rowscale1,
                    prenorm=False,
                    is_rms_norm=isinstance(self.norm1, RMSNorm),
                )
            if not isinstance(self.mlp, nn.Identity):
                mlp_out = self.mlp(hidden_states)
                if self.return_residual:  # mlp out is actually a pair here
                    mlp_out, hidden_states = mlp_out
                if not self.fused_dropout_add_ln:
                    hidden_states = self.norm2(
                        (self.drop_path2(self.dropout2(mlp_out)) + hidden_states).to(
                            dtype=self.norm2.weight.dtype
                        )
                    )
                else:
                    if self.drop_path2.p == 0 or not self.training:
                        rowscale2 = None
                    else:
                        rowscale2 = self.drop_path2(
                            torch.ones(
                                mlp_out.shape[:-1],
                                device=mlp_out.device,
                                dtype=mlp_out.dtype,
                            )
                        )
                    hidden_states = layer_norm_fn(
                        mlp_out,
                        self.norm2.weight,
                        self.norm2.bias,
                        residual=hidden_states,
                        eps=self.norm2.eps,
                        dropout_p=self.dropout2.p if self.training else 0.0,
                        rowscale=rowscale2,
                        prenorm=False,
                        is_rms_norm=isinstance(self.norm2, RMSNorm),
                    )
            return hidden_states


class ParallelBlock(nn.Module):
    """The attention (mixer) and MLP blocks are done in parallel, similar to GPT-J, GPT-NeoX,
    and PaLM.
    """

    def __init__(
        self,
        dim,
        mixer_cls=None,
        mlp_cls=None,
        norm_cls=nn.LayerNorm,
        dropout_cls=nn.Dropout,
        resid_dropout1=0.0,
        resid_dropout2=0.0,
        tied_norm=False,
        fused_dropout_add_ln=False,
        residual_in_fp32=False,
        sequence_parallel=False,
        mark_shared_params=False,
    ):
        """
        This Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA / MLP -> Dropout -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Dropout -> Add -> LN -> MHA / MLP, returning both
        the hidden_states (outputs of the MHA / MLP) and the residual.
        This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
        The residual needs to be provided (except for the very first block).
        """
        super().__init__()
        self.tied_norm = tied_norm
        self.fused_dropout_add_ln = fused_dropout_add_ln
        self.residual_in_fp32 = residual_in_fp32
        if mixer_cls is None:
            mixer_cls = partial(MHA, num_heads=dim // 64)
        if mlp_cls is None:
            mlp_cls = partial(Mlp, hidden_features=4 * dim)
        self.mixer = mixer_cls(dim)
        self.dropout1 = dropout_cls(resid_dropout1)
        self.norm1 = norm_cls(dim)
        self.mlp = mlp_cls(dim)
        self.dropout2 = dropout_cls(resid_dropout2)
        if not self.tied_norm:
            self.norm2 = norm_cls(dim)

        if self.fused_dropout_add_ln:
            assert layer_norm_fn is not None, "Triton is not installed"
            assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
                self.dropout1, nn.Dropout
            )

        # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
        # then the input to each worker in the tensor parallel group will be different.
        # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
        # For now this is not an issue because we always use sequence_parallel=True during training
        # and only use sequence_parallel=False during inference.

        # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
        if sequence_parallel:
            for p in self.norm1.parameters():
                p._sequence_parallel = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._sequence_parallel = True
        # Mark the norm parameters as "shared_params" so that we sync their values at init.
        if mark_shared_params:
            for p in self.norm1.parameters():
                p._shared_params = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._shared_params = True

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.mixer.allocate_inference_cache(
            batch_size, max_seqlen, dtype=dtype, **kwargs
        )

    def forward(
        self,
        hidden_states1: Tensor,
        hidden_states2: Optional[Tensor] = None,
        residual: Optional[Tensor] = None,
        mixer_kwargs=None,
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states1: the output of the previous attention (mixer) or embedding layer.
            hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1).
            residual: the running residual (None for the very first block).
        """
        # TODO: Ideally we should only do the allgather / allreduce once for
        # the Linear to MLP & Attention
        if not self.fused_dropout_add_ln:
            dropped1 = self.dropout1(hidden_states1)
            # For the very 1st block, we only want 1 dropout, not two different dropouts
            if hidden_states2 is not None:
                dropped2 = self.dropout2(hidden_states2)
                residual = (
                    (residual + dropped1 + dropped2)
                    if residual is not None
                    else dropped1 + dropped2
                )
            else:
                residual = (residual + dropped1) if residual is not None else dropped1
            hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
            hidden_states2 = (
                self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                if not self.tied_norm
                else hidden_states1
            )
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:
            weight2, bias2 = (
                (self.norm2.weight, self.norm2.bias)
                if not self.tied_norm
                else (None, None)
            )
            hidden_states1, *rest, residual = layer_norm_fn(
                hidden_states1,
                self.norm1.weight,
                self.norm1.bias,
                residual=residual,
                x1=hidden_states2,
                weight1=weight2,
                bias1=bias2,
                eps=self.norm1.eps,
                dropout_p=self.dropout1.p if self.training else 0.0,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                is_rms_norm=isinstance(self.norm1, RMSNorm),
            )
            if self.tied_norm:
                hidden_states2 = hidden_states1
            else:
                (hidden_states2,) = rest
        if mixer_kwargs is None:
            mixer_kwargs = {}
        hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs)
        hidden_states2 = self.mlp(hidden_states2)
        return hidden_states1, hidden_states2, residual
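The Block docstring in block.py says the reordered layout (Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both hidden_states and residual) computes the same thing as the standard prenorm layout. The sketch below checks that claim on plain Python lists. It is a simplified illustration, not the repository's code: dropout and drop-path are taken as zero, `mixer` and `mlp` are hypothetical stand-ins for attention and the feed-forward layer, and the final `hidden_states + residual` add is assumed to happen once after the last block.

```python
import math


def layer_norm(x, eps=1e-5):
    """LayerNorm over a single vector, without learnable weight or bias."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def add(a, b):
    return [u + v for u, v in zip(a, b)]


def mixer(x):
    """Hypothetical stand-in for the attention sub-layer."""
    return [0.5 * v + 0.1 for v in x]


def mlp(x):
    """Hypothetical stand-in for the feed-forward sub-layer."""
    return [v * v - 0.2 for v in x]


def standard_block(x):
    """Standard prenorm: LN -> MHA -> Add -> LN -> MLP -> Add."""
    x = add(x, mixer(layer_norm(x)))
    x = add(x, mlp(layer_norm(x)))
    return x


def reordered_block(hidden_states, residual=None):
    """Reordered prenorm: Add -> LN -> MHA -> Add -> LN -> MLP, returning a pair."""
    residual = hidden_states if residual is None else add(hidden_states, residual)
    hidden_states = mixer(layer_norm(residual))
    residual = add(hidden_states, residual)
    hidden_states = mlp(layer_norm(residual))
    return hidden_states, residual


def run_standard(x, n_blocks):
    for _ in range(n_blocks):
        x = standard_block(x)
    return x


def run_reordered(x, n_blocks):
    hidden_states, residual = x, None
    for _ in range(n_blocks):
        hidden_states, residual = reordered_block(hidden_states, residual)
    # The last Add is deferred until after the final block.
    return hidden_states if residual is None else add(hidden_states, residual)


if __name__ == "__main__":
    x0 = [0.3, -1.2, 2.0, 0.7]
    a = run_standard(x0, 3)
    b = run_reordered(x0, 3)
    print(max(abs(u - v) for u, v in zip(a, b)))
```

Each block defers its last Add to the next block, so the Add that follows a sub-layer sits directly before the next LayerNorm. That adjacency is what lets the fused path hand dropout, add and LayerNorm to a single `layer_norm_fn` call.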
config.json
ADDED
@@ -0,0 +1,51 @@
{
  "architectures": [
    "XLMRobertaForSequenceClassification"
  ],
  "attention_probs_dropout_prob": 0.1,
  "auto_map": {
    "AutoConfig": "configuration_xlm_roberta.XLMRobertaFlashConfig",
    "AutoModel": "modeling_xlm_roberta.XLMRobertaModel",
    "AutoModelForSequenceClassification": "modeling_xlm_roberta.XLMRobertaForSequenceClassification"
  },
  "bos_token_id": 0,
  "classifier_dropout": null,
  "dtype": "bfloat16",
  "emb_pooler": null,
  "eos_token_id": 2,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "LABEL_0"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "label2id": {
    "LABEL_0": 0
  },
  "layer_norm_eps": 1e-05,
  "load_trained_adapters": false,
  "lora_adaptations": null,
  "lora_alpha": 1,
  "lora_dropout_p": 0.0,
  "lora_main_params_trainable": false,
  "lora_rank": 4,
  "matryoshka_dimensions": null,
  "max_position_embeddings": 1026,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "output_past": true,
  "pad_token_id": 1,
  "position_embedding_type": "absolute",
  "sentence_transformers": {
    "activation_fn": "torch.nn.modules.activation.Sigmoid",
    "version": "5.1.0"
  },
  "transformers_version": "4.56.1",
  "truncate_dim": null,
  "type_vocab_size": 1,
  "use_cache": false,
  "use_flash_attn": true,
  "vocab_size": 250002
}
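A few shape facts follow directly from the config.json values above. The sketch below derives them from an inlined subset of the file. The helper names `summarize` and `sigmoid` are hypothetical. Reading the single label plus the Sigmoid `activation_fn` as "one relevance logit per (query, passage) pair, squashed into (0, 1)" is an assumption about how the cross-encoder wrapper uses these fields, not something the config states.

```python
import json
import math

# Subset of the config.json values shown above, inlined so the sketch is self-contained.
CONFIG_FRAGMENT = """
{
  "hidden_size": 768,
  "intermediate_size": 3072,
  "num_attention_heads": 12,
  "max_position_embeddings": 1026,
  "id2label": {"0": "LABEL_0"}
}
"""


def summarize(config):
    """Derive a few shape facts from the raw config values."""
    return {
        "head_dim": config["hidden_size"] // config["num_attention_heads"],
        "mlp_expansion": config["intermediate_size"] // config["hidden_size"],
        "num_labels": len(config["id2label"]),
    }


def sigmoid(logit):
    """Logistic function; assumed here to be what maps a raw logit to a score."""
    return 1.0 / (1.0 + math.exp(-logit))


if __name__ == "__main__":
    print(summarize(json.loads(CONFIG_FRAGMENT)))
    print(sigmoid(0.0))
```

The derived head dimension of 64 and the 4x MLP expansion match the defaults `num_heads=dim // 64` and `hidden_features=4 * dim` that block.py falls back to when no mixer or MLP class is passed.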
configuration_xlm_roberta.py
ADDED
@@ -0,0 +1,69 @@
from transformers import PretrainedConfig
import torch

class XLMRobertaFlashConfig(PretrainedConfig):
    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        position_embedding_type="absolute",
        use_cache=True,
        classifier_dropout=None,
        lora_adaptations=None,
        lora_rank=4,
        lora_dropout_p=0.0,
        lora_alpha=1,
        lora_main_params_trainable=False,
        load_trained_adapters=False,
        use_flash_attn=True,
        torch_dtype=None,
        emb_pooler=None,
        matryoshka_dimensions=None,
        truncate_dim=None,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)


        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        self.use_cache = use_cache
        self.classifier_dropout = classifier_dropout
        self.load_trained_adapters = load_trained_adapters
        self.lora_adaptations = lora_adaptations
        self.lora_rank = lora_rank
        self.lora_dropout_p = lora_dropout_p
        self.lora_alpha = lora_alpha
        self.lora_main_params_trainable = lora_main_params_trainable
        self.use_flash_attn = use_flash_attn
        self.emb_pooler = emb_pooler
        self.matryoshka_dimensions = matryoshka_dimensions
        self.truncate_dim = truncate_dim
        if torch_dtype and hasattr(torch, torch_dtype) and type(getattr(torch, torch_dtype)) is torch.dtype:
            self.torch_dtype = getattr(torch, torch_dtype)
        else:
            self.torch_dtype = torch_dtype
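The last branch of `XLMRobertaFlashConfig.__init__` turns a dtype name such as "bfloat16" into the matching dtype object, and passes anything else through unchanged. The sketch below mirrors that branch without importing PyTorch. `FakeDType`, `fake_torch` and `resolve_dtype` are hypothetical stand-ins introduced only for illustration.

```python
import types


class FakeDType:
    """Hypothetical stand-in for torch.dtype so the sketch runs without PyTorch."""

    def __init__(self, name):
        self.name = name


# Hypothetical stand-in for the torch module: two dtype attributes and one non-dtype attribute.
fake_torch = types.SimpleNamespace(
    bfloat16=FakeDType("bfloat16"),
    float16=FakeDType("float16"),
    nn=object(),
)


def resolve_dtype(value, namespace=fake_torch, dtype_type=FakeDType):
    """Map a dtype name to the dtype object; leave every other value untouched."""
    if value and hasattr(namespace, value) and type(getattr(namespace, value)) is dtype_type:
        return getattr(namespace, value)
    return value


if __name__ == "__main__":
    print(resolve_dtype("bfloat16").name)  # a known dtype name resolves to the object
    print(resolve_dtype("nn"))             # an attribute that is not a dtype stays a string
    print(resolve_dtype(None))             # None passes through
```

In plain Python, `hasattr` raises `TypeError` when the attribute name is not a string, so this sketch, like the branch it mirrors, expects a string or `None` rather than an already-resolved dtype object.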
embedding.py
ADDED
@@ -0,0 +1,62 @@
# This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/embedding.py
# Commit id: f1a73d074002226c42ce65a1df170ecff9f022c0

# Copyright (c) 2022, Tri Dao.

import torch
import torch.nn as nn
from einops import rearrange
from torch import Tensor

from transformers.models.xlm_roberta.modeling_xlm_roberta import create_position_ids_from_input_ids


class XLMRobertaEmbeddings(nn.Module):
    def __init__(
        self,
        embed_dim,
        vocab_size,
        max_position_embeddings,
        type_vocab_size,
        padding_idx=None,
        device=None,
        dtype=None,
    ):
        """
        If max_position_embeddings <= 0, there's no position embeddings
        If type_vocab_size <= 0, there's no token type embeddings
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.word_embeddings = nn.Embedding(
            vocab_size, embed_dim, padding_idx=padding_idx, **factory_kwargs
        )
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        if self.max_position_embeddings > 0:
            self.position_embeddings = nn.Embedding(
                max_position_embeddings, embed_dim, **factory_kwargs
            )
        if self.type_vocab_size > 0:
            self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim, **factory_kwargs)

    def forward(self, input_ids, position_ids=None, token_type_ids=None):
        """
        input_ids: (batch, seqlen)
        position_ids: (batch, seqlen)
        token_type_ids: (batch, seqlen)
        """
        batch_size, seqlen = input_ids.shape
        embeddings = self.word_embeddings(input_ids)
        if self.max_position_embeddings > 0:
            if position_ids is None:
                position_ids = create_position_ids_from_input_ids(input_ids, padding_idx=self.word_embeddings.padding_idx).to(input_ids.device)
                # position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
            position_embeddings = self.position_embeddings(position_ids)
            embeddings = embeddings + position_embeddings
        if self.type_vocab_size > 0:
            if token_type_ids is None:
                token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
            token_type_embeddings = self.token_type_embeddings(token_type_ids)
            embeddings = embeddings + token_type_embeddings
        return embeddings
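When `position_ids` is not supplied, embedding.py derives it from `input_ids` with `create_position_ids_from_input_ids`. The sketch below reproduces that rule for a single sequence in plain Python: non-padding tokens are numbered from `padding_idx + 1` onward, and padding tokens keep `padding_idx`. The function name `position_ids_from_input_ids` is a hypothetical stand-in, and the token ids are made up for illustration.

```python
from itertools import accumulate


def position_ids_from_input_ids(input_ids, padding_idx):
    """Pure-Python version of the RoBERTa-style position id rule for one sequence."""
    # 1 for real tokens, 0 for padding.
    mask = [int(token != padding_idx) for token in input_ids]
    # Running count of real tokens seen so far.
    running = list(accumulate(mask))
    # Real tokens get count + padding_idx; padding positions collapse to padding_idx.
    return [count * m + padding_idx for count, m in zip(running, mask)]


if __name__ == "__main__":
    # Made-up ids: 0 = <s>, 2 = </s>, 1 = <pad>, matching the bos/eos/pad ids in config.json.
    print(position_ids_from_input_ids([0, 5, 6, 2, 1, 1], padding_idx=1))
```

With `pad_token_id` equal to 1, the first real token sits at position 2. That offset is one plausible reading of why `max_position_embeddings` in config.json is 1026 rather than 1024, though the diff itself does not say so.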
mha.py
ADDED
@@ -0,0 +1,662 @@
# Copyright (c) 2023, Tri Dao.
# Adapted from https://github.com/Dao-AILab/flash-attention/pull/556

import math
from functools import partial

import torch
import torch.nn as nn
from einops import rearrange, repeat

try:
    from flash_attn import (
        flash_attn_kvpacked_func,
        flash_attn_qkvpacked_func,
        flash_attn_varlen_kvpacked_func,
        flash_attn_varlen_qkvpacked_func,
        flash_attn_with_kvcache,
    )
except ImportError:
    flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
    flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
    flash_attn_with_kvcache = None

try:
    from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear
except ImportError:
    FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None


class FlashSelfAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(
        self,
        causal=False,
        softmax_scale=None,
        attention_dropout=0.0,
        window_size=(-1, -1),
        deterministic=False,
    ):
        super().__init__()
        assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
        assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)
        self.window_size = window_size
        self.deterministic = deterministic

    def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value.
                If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D).
                If cu_seqlens is not None and max_seqlen is not None, then qkv has shape
                (total, 3, H, D), where total is the sum of the sequence lengths in the batch.
            causal: if passed, will override self.causal
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into qkv.
            max_seqlen: int. Maximum sequence length in the batch.
        Returns:
        --------
            out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None,
                else (B, S, H, D).
        """
        assert qkv.dtype in [torch.float16, torch.bfloat16]
        assert qkv.is_cuda
        causal = self.causal if causal is None else causal
        unpadded = cu_seqlens is not None

        if unpadded:
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)
            return flash_attn_varlen_qkvpacked_func(
                qkv,
                cu_seqlens,
                max_seqlen,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=None,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )
        else:
            return flash_attn_qkvpacked_func(
                qkv,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=None,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )


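The `FlashSelfAttention.forward` docstring describes an unpadded layout: all sequences are concatenated into one `(total, 3, H, D)` tensor, and `cu_seqlens` of shape `(batch_size + 1,)` marks where each sequence starts and ends. The sketch below shows how those bookkeeping values relate, using plain Python lists in place of int32 tensors. The helper names `build_cu_seqlens` and `sequence_bounds` are hypothetical and do not come from mha.py or flash-attn.

```python
from itertools import accumulate


def build_cu_seqlens(seqlens):
    """Cumulative sequence lengths with a leading 0, so the result has batch_size + 1 entries."""
    return [0] + list(accumulate(seqlens))


def sequence_bounds(cu_seqlens, index):
    """Start (inclusive) and end (exclusive) rows of one sequence in the packed tensor."""
    return cu_seqlens[index], cu_seqlens[index + 1]


if __name__ == "__main__":
    seqlens = [3, 1, 4]                     # made-up lengths for a batch of three sequences
    cu_seqlens = build_cu_seqlens(seqlens)
    total = cu_seqlens[-1]                  # number of rows in the packed tensor
    max_seqlen = max(seqlens)               # passed alongside cu_seqlens
    print(cu_seqlens, total, max_seqlen)
    print(sequence_bounds(cu_seqlens, 2))
```

In the real module, `cu_seqlens` must be a `torch.int32` tensor and `max_seqlen` a Python `int`, as the assertions in `forward` require.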
class FlashCrossAttention(nn.Module):
|
| 108 |
+
"""Implement the scaled dot product attention with softmax.
|
| 109 |
+
Arguments
|
| 110 |
+
---------
|
| 111 |
+
softmax_scale: The temperature to use for the softmax attention.
|
| 112 |
+
(default: 1/sqrt(d_keys) where d_keys is computed at
|
| 113 |
+
runtime)
|
| 114 |
+
attention_dropout: The dropout rate to apply to the attention
|
| 115 |
+
(default: 0.0)
|
| 116 |
+
"""
|
| 117 |
+
|
| 118 |
+
def __init__(
|
| 119 |
+
self,
|
| 120 |
+
causal=False,
|
| 121 |
+
softmax_scale=None,
|
| 122 |
+
attention_dropout=0.0,
|
| 123 |
+
window_size=(-1, -1),
|
| 124 |
+
deterministic=False,
|
| 125 |
+
):
|
| 126 |
+
super().__init__()
|
| 127 |
+
assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
|
| 128 |
+
assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
|
| 129 |
+
self.causal = causal
|
| 130 |
+
self.softmax_scale = softmax_scale
|
| 131 |
+
self.drop = nn.Dropout(attention_dropout)
|
| 132 |
+
self.window_size = window_size
|
| 133 |
+
self.deterministic = deterministic
|
| 134 |
+
|
| 135 |
+
def forward(
|
| 136 |
+
self,
|
| 137 |
+
q,
|
| 138 |
+
kv,
|
| 139 |
+
causal=None,
|
| 140 |
+
cu_seqlens=None,
|
| 141 |
+
max_seqlen=None,
|
| 142 |
+
cu_seqlens_k=None,
|
| 143 |
+
max_seqlen_k=None,
|
| 144 |
+
):
|
| 145 |
+
"""Implements the multihead softmax attention.
|
| 146 |
+
Arguments
|
| 147 |
+
---------
|
| 148 |
+
q: The tensor containing the query. (B, Sq, H, D)
|
| 149 |
+
kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
|
| 150 |
+
causal: if passed, will override self.causal
|
| 151 |
+
cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
|
| 152 |
+
of the sequences in the batch, used to index into q.
|
| 153 |
+
max_seqlen: int. Maximum sequence length in the batch of q.
|
| 154 |
+
cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
|
| 155 |
+
of the sequences in the batch, used to index into kv.
|
| 156 |
+
max_seqlen_k: int. Maximum sequence length in the batch of k and v.
|
| 157 |
+
"""
|
| 158 |
+
assert q.dtype in [torch.float16, torch.bfloat16]
|
| 159 |
+
assert q.is_cuda and kv.is_cuda
|
| 160 |
+
causal = self.causal if causal is None else causal
|
| 161 |
+
unpadded = cu_seqlens is not None
|
| 162 |
+
|
| 163 |
+
if unpadded:
|
| 164 |
+
assert cu_seqlens.dtype == torch.int32
|
| 165 |
+
assert max_seqlen is not None
|
| 166 |
+
assert isinstance(max_seqlen, int)
|
| 167 |
+
assert cu_seqlens_k is not None
|
| 168 |
+
assert cu_seqlens_k.dtype == torch.int32
|
| 169 |
+
assert max_seqlen_k is not None
|
| 170 |
+
assert isinstance(max_seqlen_k, int)
|
| 171 |
+
return flash_attn_varlen_kvpacked_func(
|
| 172 |
+
q,
|
| 173 |
+
kv,
|
| 174 |
+
cu_seqlens,
|
| 175 |
+
cu_seqlens_k,
|
| 176 |
+
max_seqlen,
|
| 177 |
+
max_seqlen_k,
|
| 178 |
+
self.drop.p if self.training else 0.0,
|
| 179 |
+
softmax_scale=self.softmax_scale,
|
| 180 |
+
causal=causal,
|
| 181 |
+
alibi_slopes=None,
|
| 182 |
+
window_size=self.window_size,
|
| 183 |
+
deterministic=self.deterministic,
|
| 184 |
+
)
|
| 185 |
+
else:
|
| 186 |
+
batch_size, seqlen_q = q.shape[0], q.shape[1]
|
| 187 |
+
seqlen_k = kv.shape[1]
|
| 188 |
+
assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
|
| 189 |
+
return flash_attn_kvpacked_func(
|
| 190 |
+
q,
|
| 191 |
+
kv,
|
| 192 |
+
self.drop.p if self.training else 0.0,
|
| 193 |
+
causal=causal,
|
| 194 |
+
softmax_scale=self.softmax_scale,
|
| 195 |
+
alibi_slopes=None,
|
| 196 |
+
window_size=self.window_size,
|
| 197 |
+
deterministic=self.deterministic,
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
class SelfAttention(nn.Module):
|
| 202 |
+
"""Implement the scaled dot product attention with softmax.
|
| 203 |
+
Arguments
|
| 204 |
+
---------
|
| 205 |
+
softmax_scale: The temperature to use for the softmax attention.
|
| 206 |
+
(default: 1/sqrt(d_keys) where d_keys is computed at
|
| 207 |
+
runtime)
|
| 208 |
+
attention_dropout: The dropout rate to apply to the attention
|
| 209 |
+
(default: 0.0)
|
| 210 |
+
"""
|
| 211 |
+
|
| 212 |
+
def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
|
| 213 |
+
super().__init__()
|
| 214 |
+
self.causal = causal
|
| 215 |
+
self.softmax_scale = softmax_scale
|
| 216 |
+
self.drop = nn.Dropout(attention_dropout)
|
| 217 |
+
|
| 218 |
+
def forward(self, qkv, causal=None, key_padding_mask=None):
|
| 219 |
+
"""Implements the multihead softmax attention.
|
| 220 |
+
Arguments
|
| 221 |
+
---------
|
| 222 |
+
qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
|
| 223 |
+
causal: if passed, will override self.causal
|
| 224 |
+
key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
|
| 225 |
+
False means to mask out. (B, S)
|
| 226 |
+
"""
|
| 227 |
+
batch_size, seqlen = qkv.shape[0], qkv.shape[1]
|
| 228 |
+
causal = self.causal if causal is None else causal
|
| 229 |
+
q, k, v = qkv.unbind(dim=2)
|
| 230 |
+
softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
|
| 231 |
+
scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
|
| 232 |
+
if key_padding_mask is not None:
|
| 233 |
+
padding_mask = torch.full(
|
| 234 |
+
(batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device
|
| 235 |
+
)
|
| 236 |
+
padding_mask.masked_fill_(key_padding_mask, 0.0)
|
| 237 |
+
# TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
|
| 238 |
+
scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
|
| 239 |
+
if causal:
|
| 240 |
+
# "triu_tril_cuda_template" not implemented for 'BFloat16'
|
| 241 |
+
# So we have to construct the mask in float
|
| 242 |
+
causal_mask = torch.triu(
|
| 243 |
+
torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1
|
| 244 |
+
)
|
| 245 |
+
# TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
|
| 246 |
+
scores = scores + causal_mask.to(dtype=scores.dtype)
|
| 247 |
+
attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
|
| 248 |
+
attention_drop = self.drop(attention)
|
| 249 |
+
output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
|
| 250 |
+
return output
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
class CrossAttention(nn.Module):
|
| 254 |
+
"""Implement the scaled dot product attention with softmax.
|
| 255 |
+
Arguments
|
| 256 |
+
---------
|
| 257 |
+
softmax_scale: The temperature to use for the softmax attention.
|
| 258 |
+
(default: 1/sqrt(d_keys) where d_keys is computed at
|
| 259 |
+
runtime)
|
| 260 |
+
attention_dropout: The dropout rate to apply to the attention
|
| 261 |
+
(default: 0.0)
|
| 262 |
+
"""
|
| 263 |
+
|
| 264 |
+
def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
|
| 265 |
+
super().__init__()
|
| 266 |
+
self.causal = causal
|
| 267 |
+
self.softmax_scale = softmax_scale
|
| 268 |
+
self.drop = nn.Dropout(attention_dropout)
|
| 269 |
+
|
| 270 |
+
def forward(self, q, kv, causal=None, key_padding_mask=None):
|
| 271 |
+
"""Implements the multihead softmax attention.
|
| 272 |
+
Arguments
|
| 273 |
+
---------
|
| 274 |
+
q: The tensor containing the query. (B, Sq, H, D)
|
| 275 |
+
kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
|
| 276 |
+
causal: if passed, will override self.causal
|
| 277 |
+
key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
|
| 278 |
+
False means to mask out. (B, Sk)
|
| 279 |
+
"""
|
| 280 |
+
batch_size, seqlen_q = q.shape[0], q.shape[1]
|
| 281 |
+
causal = self.causal if causal is None else causal
|
| 282 |
+
seqlen_k = kv.shape[1]
|
| 283 |
+
assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
|
| 284 |
+
if kv.shape[3] != q.shape[2]: # MQA/GQA
|
| 285 |
+
kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
|
| 286 |
+
k, v = kv.unbind(dim=2)
|
| 287 |
+
softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
|
| 288 |
+
scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
|
| 289 |
+
if key_padding_mask is not None:
|
| 290 |
+
padding_mask = torch.full(
|
| 291 |
+
(batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device
|
| 292 |
+
)
|
| 293 |
+
padding_mask.masked_fill_(key_padding_mask, 0.0)
|
| 294 |
+
# TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
|
| 295 |
+
scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
|
| 296 |
+
if causal:
|
| 297 |
+
# causal mask needs to take into account the difference between seqlen_q and seqlen_k
|
| 298 |
+
row_idx = rearrange(
|
| 299 |
+
torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1"
|
| 300 |
+
)
|
| 301 |
+
col_idx = torch.arange(seqlen_k, device=kv.device, dtype=torch.long)
|
| 302 |
+
sk = (
|
| 303 |
+
seqlen_k
|
| 304 |
+
if key_padding_mask is None
|
| 305 |
+
else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
|
| 306 |
+
)
|
| 307 |
+
causal_mask = col_idx > row_idx + sk - seqlen_q
|
| 308 |
+
scores = scores.masked_fill(causal_mask, -10000.0)
|
| 309 |
+
attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
|
| 310 |
+
attention_drop = self.drop(attention)
|
| 311 |
+
output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
|
| 312 |
+
return output
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
class LinearResidual(nn.Linear):
|
| 316 |
+
"""Wrap nn.Linear to return the residual as well. For compatibility with FusedDense."""
|
| 317 |
+
|
| 318 |
+
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
| 319 |
+
return super().forward(input), input
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
def _update_kv_cache(kv, inference_params, layer_idx):
|
| 323 |
+
"""kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
|
| 324 |
+
# Pre-allocate memory for key-values for inference.
|
| 325 |
+
num_heads, head_dim = kv.shape[-2:]
|
| 326 |
+
if layer_idx not in inference_params.key_value_memory_dict:
|
| 327 |
+
kv_cache = torch.empty(
|
| 328 |
+
inference_params.max_batch_size,
|
| 329 |
+
inference_params.max_seqlen,
|
| 330 |
+
2,
|
| 331 |
+
num_heads,
|
| 332 |
+
head_dim,
|
| 333 |
+
dtype=kv.dtype,
|
| 334 |
+
device=kv.device,
|
| 335 |
+
)
|
| 336 |
+
inference_params.key_value_memory_dict[layer_idx] = kv_cache
|
| 337 |
+
else:
|
| 338 |
+
kv_cache = inference_params.key_value_memory_dict[layer_idx]
|
| 339 |
+
# Adjust key and value for inference
|
| 340 |
+
batch_start = inference_params.batch_size_offset
|
| 341 |
+
batch_end = batch_start + kv.shape[0]
|
| 342 |
+
sequence_start = inference_params.seqlen_offset
|
| 343 |
+
sequence_end = sequence_start + kv.shape[1]
|
| 344 |
+
assert batch_end <= kv_cache.shape[0]
|
| 345 |
+
assert sequence_end <= kv_cache.shape[1]
|
| 346 |
+
assert kv_cache is not None
|
| 347 |
+
kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
|
| 348 |
+
return kv_cache[batch_start:batch_end, :sequence_end, ...]
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
class MHA(nn.Module):
|
| 352 |
+
"""Multi-head self-attention and cross-attention"""
|
| 353 |
+
|
| 354 |
+
def __init__(
|
| 355 |
+
self,
|
| 356 |
+
embed_dim,
|
| 357 |
+
num_heads,
|
| 358 |
+
num_heads_kv=None,
|
| 359 |
+
cross_attn=False,
|
| 360 |
+
qkv_proj_bias=True,
|
| 361 |
+
out_proj_bias=True,
|
| 362 |
+
dropout=0.0,
|
| 363 |
+
softmax_scale=None,
|
| 364 |
+
causal=False,
|
| 365 |
+
layer_idx=None,
|
| 366 |
+
dwconv=False,
|
| 367 |
+
window_size=(-1, -1),
|
| 368 |
+
fused_bias_fc=False,
|
| 369 |
+
use_flash_attn=False,
|
| 370 |
+
return_residual=False,
|
| 371 |
+
checkpointing=False,
|
| 372 |
+
device=None,
|
| 373 |
+
dtype=None,
|
| 374 |
+
) -> None:
|
| 375 |
+
"""
|
| 376 |
+
num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
|
| 377 |
+
return_residual: whether to return the input x along with the output. This is for
|
| 378 |
+
performance reasons: for post-norm architectures, returning the input allows us
|
| 379 |
+
to fuse the backward of nn.Linear with the residual connection.
|
| 380 |
+
"""
|
| 381 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
| 382 |
+
super().__init__()
|
| 383 |
+
self.embed_dim = embed_dim
|
| 384 |
+
self.cross_attn = cross_attn
|
| 385 |
+
self.causal = causal
|
| 386 |
+
self.layer_idx = layer_idx
|
| 387 |
+
self.dwconv = dwconv
|
| 388 |
+
self.use_flash_attn = use_flash_attn
|
| 389 |
+
self.return_residual = return_residual
|
| 390 |
+
self.checkpointing = checkpointing
|
| 391 |
+
|
| 392 |
+
if window_size != (-1, -1):
|
| 393 |
+
assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"
|
| 394 |
+
|
| 395 |
+
self.num_heads = num_heads
|
| 396 |
+
self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
|
| 397 |
+
assert (
|
| 398 |
+
self.num_heads % self.num_heads_kv == 0
|
| 399 |
+
), "num_heads must be divisible by num_heads_kv"
|
| 400 |
+
assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
|
| 401 |
+
self.head_dim = self.embed_dim // num_heads
|
| 402 |
+
qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
|
| 403 |
+
kv_dim = 2 * self.head_dim * self.num_heads_kv
|
| 404 |
+
|
| 405 |
+
if fused_bias_fc and FusedDense is None:
|
| 406 |
+
raise ImportError("fused_dense is not installed")
|
| 407 |
+
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
|
| 408 |
+
linear_resid_cls = (
|
| 409 |
+
LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
|
| 410 |
+
)
|
| 411 |
+
wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
|
| 412 |
+
inner_attn_cls = (
|
| 413 |
+
partial(FlashSelfAttention, window_size=window_size)
|
| 414 |
+
if use_flash_attn
|
| 415 |
+
else SelfAttention
|
| 416 |
+
)
|
| 417 |
+
inner_cross_attn_cls = (
|
| 418 |
+
partial(FlashCrossAttention, window_size=window_size)
|
| 419 |
+
if use_flash_attn
|
| 420 |
+
else CrossAttention
|
| 421 |
+
)
|
| 422 |
+
if not self.cross_attn:
|
| 423 |
+
self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)
|
| 424 |
+
else:
|
| 425 |
+
self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs)
|
| 426 |
+
self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
|
| 427 |
+
if self.dwconv:
|
| 428 |
+
if self.num_heads_kv == self.num_heads:
|
| 429 |
+
self.dwconv_qkv = nn.Conv1d(
|
| 430 |
+
qkv_dim, qkv_dim, kernel_size=3, padding=2, groups=qkv_dim
|
| 431 |
+
)
|
| 432 |
+
else:
|
| 433 |
+
self.dwconv_q = nn.Conv1d(
|
| 434 |
+
embed_dim, embed_dim, kernel_size=3, padding=2, groups=embed_dim
|
| 435 |
+
)
|
| 436 |
+
self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim)
|
| 437 |
+
self.inner_attn = inner_attn_cls(
|
| 438 |
+
causal=causal,
|
| 439 |
+
softmax_scale=softmax_scale,
|
| 440 |
+
attention_dropout=dropout,
|
| 441 |
+
)
|
| 442 |
+
self.inner_cross_attn = inner_cross_attn_cls(
|
| 443 |
+
causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
|
| 444 |
+
)
|
| 445 |
+
self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)
|
| 446 |
+
|
| 447 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
|
| 448 |
+
dtype = self.out_proj.weight.dtype if dtype is None else dtype
|
| 449 |
+
device = self.out_proj.weight.device
|
| 450 |
+
return torch.empty(
|
| 451 |
+
batch_size,
|
| 452 |
+
max_seqlen,
|
| 453 |
+
2,
|
| 454 |
+
self.num_heads_kv,
|
| 455 |
+
self.head_dim,
|
| 456 |
+
dtype=dtype,
|
| 457 |
+
device=device,
|
| 458 |
+
)
|
| 459 |
+
|
| 460 |
+
def _update_kv_cache(self, kv, inference_params):
|
| 461 |
+
"""kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
|
| 462 |
+
assert not self.dwconv, "Generation does not support dwconv yet"
|
| 463 |
+
assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
|
| 464 |
+
return _update_kv_cache(kv, inference_params, self.layer_idx)
|
| 465 |
+
|
| 466 |
+
def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
|
| 467 |
+
"""
|
| 468 |
+
Fast path that combines 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
|
| 469 |
+
q: (batch_size, seqlen_q, nheads, head_dim)
|
| 470 |
+
kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
|
| 471 |
+
"""
|
| 472 |
+
assert inference_params is not None and inference_params.seqlen_offset > 0
|
| 473 |
+
assert self.use_flash_attn
|
| 474 |
+
batch = q.shape[0]
|
| 475 |
+
kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
|
| 476 |
+
cache_seqlens = (
|
| 477 |
+
inference_params.lengths_per_sample[:batch]
|
| 478 |
+
if inference_params.lengths_per_sample is not None
|
| 479 |
+
else inference_params.seqlen_offset
|
| 480 |
+
)
|
| 481 |
+
context = flash_attn_with_kvcache(
|
| 482 |
+
q,
|
| 483 |
+
kv_cache[:, :, 0],
|
| 484 |
+
kv_cache[:, :, 1],
|
| 485 |
+
kv[:, :, 0],
|
| 486 |
+
kv[:, :, 1],
|
| 487 |
+
cache_seqlens=cache_seqlens,
|
| 488 |
+
softmax_scale=self.inner_cross_attn.softmax_scale,
|
| 489 |
+
causal=self.inner_cross_attn.causal,
|
| 490 |
+
rotary_interleaved=False,
|
| 491 |
+
alibi_slopes=None,
|
| 492 |
+
)
|
| 493 |
+
return context
|
| 494 |
+
|
| 495 |
+
def _update_kvcache_attention(self, q, kv, inference_params):
|
| 496 |
+
"""Write kv to inference_params, then do attention"""
|
| 497 |
+
if (
|
| 498 |
+
inference_params.seqlen_offset == 0
|
| 499 |
+
or flash_attn_with_kvcache is None
|
| 500 |
+
or not self.use_flash_attn
|
| 501 |
+
):
|
| 502 |
+
# TODO: this only uses seqlen_offset and not lengths_per_sample.
|
| 503 |
+
kv = self._update_kv_cache(kv, inference_params)
|
| 504 |
+
return self.inner_cross_attn(q, kv)
|
| 505 |
+
else:
|
| 506 |
+
batch = q.shape[0]
|
| 507 |
+
kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
|
| 508 |
+
cache_seqlens = (
|
| 509 |
+
inference_params.lengths_per_sample[:batch]
|
| 510 |
+
if inference_params.lengths_per_sample is not None
|
| 511 |
+
else inference_params.seqlen_offset
|
| 512 |
+
)
|
| 513 |
+
return flash_attn_with_kvcache(
|
| 514 |
+
q,
|
| 515 |
+
kv_cache[:, :, 0],
|
| 516 |
+
kv_cache[:, :, 1],
|
| 517 |
+
kv[:, :, 0],
|
| 518 |
+
kv[:, :, 1],
|
| 519 |
+
cache_seqlens=cache_seqlens,
|
| 520 |
+
softmax_scale=self.inner_cross_attn.softmax_scale,
|
| 521 |
+
causal=self.inner_cross_attn.causal,
|
| 522 |
+
alibi_slopes=None,
|
| 523 |
+
)
|
| 524 |
+
|
| 525 |
+
def forward(
|
| 526 |
+
self,
|
| 527 |
+
x,
|
| 528 |
+
x_kv=None,
|
| 529 |
+
key_padding_mask=None,
|
| 530 |
+
cu_seqlens=None,
|
| 531 |
+
max_seqlen=None,
|
| 532 |
+
mixer_subset=None,
|
| 533 |
+
inference_params=None,
|
| 534 |
+
**kwargs,
|
| 535 |
+
):
|
| 536 |
+
"""
|
| 537 |
+
Arguments:
|
| 538 |
+
x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
|
| 539 |
+
cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
|
| 540 |
+
is the sum of the sequence lengths in the batch.
|
| 541 |
+
x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
|
| 542 |
+
cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
|
| 543 |
+
of the sequences in the batch, used to index into x. Only applicable when using
|
| 544 |
+
FlashAttention.
|
| 545 |
+
max_seqlen: int. Maximum sequence length in the batch.
|
| 546 |
+
key_padding_mask: boolean mask, True means to keep, False means to mask out.
|
| 547 |
+
(batch, seqlen). Only applicable when not using FlashAttention.
|
| 548 |
+
mixer_subset: for cross-attention only. If not None, will take a subset of x
|
| 549 |
+
before applying the query projection. Useful for e.g., ViT where we only care
|
| 550 |
+
about the CLS token in the last layer.
|
| 551 |
+
inference_params: for generation. Adapted from Megatron-LM (and Apex)
|
| 552 |
+
https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
|
| 553 |
+
"""
|
| 554 |
+
if cu_seqlens is not None:
|
| 555 |
+
assert max_seqlen is not None
|
| 556 |
+
assert key_padding_mask is None
|
| 557 |
+
assert self.use_flash_attn
|
| 558 |
+
assert not self.dwconv
|
| 559 |
+
if key_padding_mask is not None:
|
| 560 |
+
assert cu_seqlens is None
|
| 561 |
+
assert max_seqlen is None
|
| 562 |
+
assert not self.use_flash_attn
|
| 563 |
+
if inference_params is not None:
|
| 564 |
+
assert key_padding_mask is None
|
| 565 |
+
assert cu_seqlens is None and max_seqlen is None
|
| 566 |
+
assert not self.dwconv
|
| 567 |
+
|
| 568 |
+
kwargs = (
|
| 569 |
+
{"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, **kwargs}
|
| 570 |
+
if self.use_flash_attn
|
| 571 |
+
else {"key_padding_mask": key_padding_mask, **kwargs}
|
| 572 |
+
)
|
| 573 |
+
seqlen_offset = (
|
| 574 |
+
0
|
| 575 |
+
if inference_params is None
|
| 576 |
+
else (
|
| 577 |
+
inference_params.lengths_per_sample
|
| 578 |
+
if inference_params.lengths_per_sample is not None
|
| 579 |
+
else inference_params.seqlen_offset
|
| 580 |
+
)
|
| 581 |
+
)
|
| 582 |
+
rotary_max_seqlen = (
|
| 583 |
+
inference_params.max_seqlen if inference_params is not None else max_seqlen
|
| 584 |
+
)
|
| 585 |
+
batch, seqlen = x.shape[:2]
|
| 586 |
+
if not self.cross_attn and self.num_heads_kv == self.num_heads:
|
| 587 |
+
assert x_kv is None and mixer_subset is None
|
| 588 |
+
if not self.return_residual:
|
| 589 |
+
qkv = self.Wqkv(x)
|
| 590 |
+
else:
|
| 591 |
+
qkv, x = self.Wqkv(x)
|
| 592 |
+
if self.dwconv:
|
| 593 |
+
qkv = rearrange(
|
| 594 |
+
self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
|
| 595 |
+
).contiguous()
|
| 596 |
+
qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
|
| 597 |
+
if (
|
| 598 |
+
inference_params is None
|
| 599 |
+
or inference_params.seqlen_offset == 0
|
| 600 |
+
or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
|
| 601 |
+
or not self.use_flash_attn
|
| 602 |
+
):
|
| 603 |
+
if inference_params is None:
|
| 604 |
+
if not self.checkpointing:
|
| 605 |
+
context = self.inner_attn(qkv, **kwargs)
|
| 606 |
+
else:
|
| 607 |
+
context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
|
| 608 |
+
else:
|
| 609 |
+
context = self._update_kvcache_attention(
|
| 610 |
+
qkv[:, :, 0], qkv[:, :, 1:], inference_params
|
| 611 |
+
)
|
| 612 |
+
else:
|
| 613 |
+
context = self._apply_rotary_update_kvcache_attention(
|
| 614 |
+
qkv[:, :, 0], qkv[:, :, 1:], inference_params
|
| 615 |
+
)
|
| 616 |
+
else:
|
| 617 |
+
if self.cross_attn:
|
| 618 |
+
if not self.return_residual:
|
| 619 |
+
q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
|
| 620 |
+
kv = self.Wkv(x_kv if x_kv is not None else x)
|
| 621 |
+
else:
|
| 622 |
+
if x_kv is not None:
|
| 623 |
+
kv, x_kv = self.Wkv(x_kv)
|
| 624 |
+
else:
|
| 625 |
+
kv, x = self.Wkv(x)
|
| 626 |
+
q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
|
| 627 |
+
else:
|
| 628 |
+
assert self.num_heads_kv != self.num_heads
|
| 629 |
+
if not self.return_residual:
|
| 630 |
+
qkv = self.Wqkv(x)
|
| 631 |
+
else:
|
| 632 |
+
qkv, x = self.Wqkv(x)
|
| 633 |
+
q = qkv[..., : self.num_heads * self.head_dim]
|
| 634 |
+
kv = qkv[..., self.num_heads * self.head_dim :]
|
| 635 |
+
q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
|
| 636 |
+
kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
|
| 637 |
+
if self.dwconv:
|
| 638 |
+
q = rearrange(
|
| 639 |
+
self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
|
| 640 |
+
).contiguous()
|
| 641 |
+
kv = rearrange(
|
| 642 |
+
self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
|
| 643 |
+
).contiguous()
|
| 644 |
+
if (
|
| 645 |
+
inference_params is None
|
| 646 |
+
or inference_params.seqlen_offset == 0
|
| 647 |
+
or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
|
| 648 |
+
or not self.use_flash_attn
|
| 649 |
+
):
|
| 650 |
+
if inference_params is None:
|
| 651 |
+
if not self.checkpointing:
|
| 652 |
+
context = self.inner_cross_attn(q, kv, **kwargs)
|
| 653 |
+
else:
|
| 654 |
+
context = torch.utils.checkpoint.checkpoint(
|
| 655 |
+
self.inner_cross_attn, q, kv, **kwargs
|
| 656 |
+
)
|
| 657 |
+
else:
|
| 658 |
+
context = self._update_kvcache_attention(q, kv, inference_params)
|
| 659 |
+
else:
|
| 660 |
+
context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
|
| 661 |
+
out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
|
| 662 |
+
return out if not self.return_residual else (out, x)
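The non-flash `CrossAttention.forward` in mha.py builds its causal mask as `col_idx > row_idx + sk - seqlen_q`, which aligns the last query with the last key when the two lengths differ. The following is a minimal pure-Python sketch of that rule for the case without a `key_padding_mask` (so `sk == seqlen_k`). The helper name `cross_attn_causal_mask` is hypothetical and is not part of the file; nested lists stand in for tensors.

```python
def cross_attn_causal_mask(seqlen_q, seqlen_k):
    """Return mask[i][j] == True where key j must be hidden from query i.

    Mirrors `col_idx > row_idx + sk - seqlen_q` with sk == seqlen_k,
    i.e. the case where no key_padding_mask is given (an assumption of this sketch).
    """
    offset = seqlen_k - seqlen_q  # shifts the diagonal so the last query sees all keys
    return [
        [col > row + offset for col in range(seqlen_k)]
        for row in range(seqlen_q)
    ]


# Two new queries attending over four keys (e.g. two cached plus two new positions).
for row in cross_attn_causal_mask(2, 4):
    print(row)
# [False, False, False, True]
# [False, False, False, False]
```

With equal lengths the offset is zero and the mask reduces to the usual strictly upper-triangular pattern.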
|
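The docstrings in mha.py describe `cu_seqlens` as a `(batch_size + 1,)` vector of cumulative sequence lengths and `max_seqlen` as the longest sequence in the batch. The sketch below shows how those two values relate to a list of per-sequence lengths. It is an illustration only: the helper name `build_cu_seqlens` is hypothetical, and it returns plain Python integers, whereas the code above asserts a `torch.int32` tensor for `cu_seqlens`.

```python
from itertools import accumulate


def build_cu_seqlens(lengths):
    """Return (cu_seqlens, max_seqlen) for a batch of unpadded sequences.

    cu_seqlens[i] is the offset of sequence i in the packed (total, ...) tensor,
    and cu_seqlens[-1] is the total number of tokens.
    """
    cu_seqlens = [0] + list(accumulate(lengths))
    return cu_seqlens, max(lengths)


cu_seqlens, max_seqlen = build_cu_seqlens([3, 5, 2])
print(cu_seqlens)  # [0, 3, 8, 10]
print(max_seqlen)  # 5

# Tokens of sequence 1 occupy rows cu_seqlens[1]:cu_seqlens[2] of the packed input.
print(cu_seqlens[1], cu_seqlens[2])  # 3 8
```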
mlp.py
ADDED
|
@@ -0,0 +1,194 @@
|
| 1 |
+
# This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mlp.py
|
| 2 |
+
# Commit id: c3b219665292c61a51153d0ded4473c494296382
|
| 3 |
+
|
| 4 |
+
# Copyright (c) 2023, Tri Dao.
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from torch.distributed import ProcessGroup
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
try:
|
| 13 |
+
from flash_attn.ops.activations import swiglu
|
| 14 |
+
except ImportError:
|
| 15 |
+
swiglu = None
|
| 16 |
+
|
| 17 |
+
try:
|
| 18 |
+
from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
|
| 19 |
+
except ImportError:
|
| 20 |
+
ColumnParallelLinear, RowParallelLinear = None, None
|
| 21 |
+
|
| 22 |
+
try:
|
| 23 |
+
from flash_attn.ops.fused_dense import FusedMLP, ParallelFusedMLP
|
| 24 |
+
except ImportError:
|
| 25 |
+
FusedMLP, ParallelFusedMLP = None, None
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class Mlp(nn.Module):
|
| 29 |
+
def __init__(
|
| 30 |
+
self,
|
| 31 |
+
in_features,
|
| 32 |
+
hidden_features=None,
|
| 33 |
+
out_features=None,
|
| 34 |
+
activation=F.gelu,
|
| 35 |
+
bias1=True,
|
| 36 |
+
bias2=True,
|
| 37 |
+
return_residual=False,
|
| 38 |
+
device=None,
|
| 39 |
+
dtype=None,
|
| 40 |
+
):
|
| 41 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
| 42 |
+
super().__init__()
|
| 43 |
+
out_features = out_features if out_features is not None else in_features
|
| 44 |
+
hidden_features = hidden_features if hidden_features is not None else in_features * 4
|
| 45 |
+
self.return_residual = return_residual
|
| 46 |
+
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
|
| 47 |
+
self.activation = activation
|
| 48 |
+
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
|
| 49 |
+
|
| 50 |
+
def forward(self, x):
|
| 51 |
+
y = self.fc1(x)
|
| 52 |
+
y = self.activation(y)
|
| 53 |
+
y = self.fc2(y)
|
| 54 |
+
return y if not self.return_residual else (y, x)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class ParallelMLP(nn.Module):
|
| 58 |
+
def __init__(
|
| 59 |
+
self,
|
| 60 |
+
in_features,
|
| 61 |
+
hidden_features=None,
|
| 62 |
+
out_features=None,
|
| 63 |
+
activation=F.gelu,
|
| 64 |
+
process_group: ProcessGroup = None,
|
| 65 |
+
sequence_parallel=True,
|
| 66 |
+
bias1=True,
|
| 67 |
+
bias2=True,
|
| 68 |
+
device=None,
|
| 69 |
+
dtype=None,
|
| 70 |
+
):
|
| 71 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
| 72 |
+
super().__init__()
|
| 73 |
+
assert ColumnParallelLinear is not None, "Need to install fused_dense"
|
| 74 |
+
assert RowParallelLinear is not None, "Need to install fused_dense"
|
| 75 |
+
out_features = out_features if out_features is not None else in_features
|
| 76 |
+
hidden_features = hidden_features if hidden_features is not None else in_features * 4
|
| 77 |
+
self.fc1 = ColumnParallelLinear(
|
| 78 |
+
in_features,
|
| 79 |
+
hidden_features,
|
| 80 |
+
process_group,
|
| 81 |
+
bias=bias1,
|
| 82 |
+
sequence_parallel=sequence_parallel,
|
| 83 |
+
**factory_kwargs,
|
| 84 |
+
)
|
| 85 |
+
self.activation = activation
|
| 86 |
+
self.fc2 = RowParallelLinear(
|
| 87 |
+
hidden_features,
|
| 88 |
+
out_features,
|
| 89 |
+
process_group,
|
| 90 |
+
bias=bias2,
|
| 91 |
+
sequence_parallel=sequence_parallel,
|
| 92 |
+
**factory_kwargs,
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
def forward(self, x):
|
| 96 |
+
y = self.fc1(x)
|
| 97 |
+
y = self.activation(y)
|
| 98 |
+
y = self.fc2(y)
|
| 99 |
+
return y
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class GatedMlp(nn.Module):
|
| 103 |
+
def __init__(
|
| 104 |
+
self,
|
| 105 |
+
in_features,
|
| 106 |
+
hidden_features=None,
|
| 107 |
+
out_features=None,
|
| 108 |
+
activation=F.sigmoid,
|
| 109 |
+
bias1=True,
|
| 110 |
+
bias2=True,
|
| 111 |
+
multiple_of=128,
|
| 112 |
+
return_residual=False,
|
| 113 |
+
device=None,
|
| 114 |
+
dtype=None,
|
| 115 |
+
):
|
| 116 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
| 117 |
+
super().__init__()
|
| 118 |
+
out_features = out_features if out_features is not None else in_features
|
| 119 |
+
hidden_features = (
|
| 120 |
+
hidden_features if hidden_features is not None else int(8 * in_features / 3)
|
| 121 |
+
)
|
| 122 |
+
hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
|
| 123 |
+
self.return_residual = return_residual
|
| 124 |
+
self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias1, **factory_kwargs)
|
| 125 |
+
self.activation = activation
|
| 126 |
+
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
|
| 127 |
+
|
| 128 |
+
def forward(self, x):
|
| 129 |
+
y = self.fc1(x)
|
| 130 |
+
if self.activation == F.sigmoid: # Special case for GLU
|
| 131 |
+
y = F.glu(y, dim=-1)
|
| 132 |
+
elif self.activation == F.silu and swiglu is not None: # Special case for SwiGLU
|
| 133 |
+
y, gate = y.chunk(2, dim=-1)
|
| 134 |
+
y = swiglu(gate, y)
|
| 135 |
+
else:
|
| 136 |
+
y, gate = y.chunk(2, dim=-1)
|
| 137 |
+
y = y * self.activation(gate)
|
| 138 |
+
y = self.fc2(y)
|
| 139 |
+
return y if not self.return_residual else (y, x)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
class ParallelGatedMlp(nn.Module):
|
| 143 |
+
"""Parallel GatedMlp"""
|
| 144 |
+
|
| 145 |
+
def __init__(
|
| 146 |
+
self,
|
| 147 |
+
in_features,
|
| 148 |
+
process_group,
|
| 149 |
+
hidden_features=None,
|
| 150 |
+
out_features=None,
|
| 151 |
+
activation=F.sigmoid,
|
| 152 |
+
bias1=True,
|
| 153 |
+
bias2=True,
|
| 154 |
+
multiple_of=128,
|
| 155 |
+
sequence_parallel=True,
|
| 156 |
+
device=None,
|
| 157 |
+
dtype=None,
|
| 158 |
+
):
|
| 159 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
| 160 |
+
super().__init__()
|
| 161 |
+
out_features = out_features if out_features is not None else in_features
|
| 162 |
+
hidden_features = (
|
| 163 |
+
hidden_features if hidden_features is not None else int(8 * in_features / 3)
|
| 164 |
+
)
|
| 165 |
+
hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
|
| 166 |
+
if ColumnParallelLinear is None or RowParallelLinear is None:
|
| 167 |
+
raise ImportError("fused_dense is not installed")
|
| 168 |
+
self.fc1 = ColumnParallelLinear(
|
| 169 |
+
in_features,
|
| 170 |
+
2 * hidden_features,
|
| 171 |
+
process_group,
|
| 172 |
+
bias=bias1,
|
| 173 |
+
sequence_parallel=sequence_parallel,
|
| 174 |
+
**factory_kwargs,
|
| 175 |
+
)
|
| 176 |
+
self.activation = activation
|
| 177 |
+
self.fc2 = RowParallelLinear(
|
| 178 |
+
hidden_features,
|
| 179 |
+
out_features,
|
| 180 |
+
process_group,
|
| 181 |
+
bias=bias2,
|
| 182 |
+
sequence_parallel=sequence_parallel,
|
| 183 |
+
**factory_kwargs,
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
def forward(self, x):
|
| 187 |
+
y = self.fc1(x)
|
| 188 |
+
if self.activation == F.sigmoid: # Special case for GLU
|
| 189 |
+
y = F.glu(y, dim=-1)
|
| 190 |
+
else:
|
| 191 |
+
y, gate = y.chunk(2, dim=-1)
|
| 192 |
+
y = y * self.activation(gate)
|
| 193 |
+
y = self.fc2(y)
|
| 194 |
+
return y
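Both gated MLP classes in mlp.py split the output of `fc1` into two halves along the last dimension and multiply the first half by an activation of the second half. The sketch below reproduces that gating on a flat Python list, assuming a single feature vector. The helper names `gated` and `sigmoid` are hypothetical and chosen only for this illustration.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def gated(values, activation=sigmoid):
    """Split `values` in two halves and return first_half * activation(second_half).

    This follows the `y, gate = y.chunk(2, dim=-1); y = y * activation(gate)`
    pattern; with a sigmoid activation it corresponds to a GLU.
    """
    half = len(values) // 2
    y, gate = values[:half], values[half:]
    return [v * activation(g) for v, g in zip(y, gate)]


# fc1 output of width 2 * hidden_features = 4, so the gated result has width 2.
print(gated([2.0, -4.0, 0.0, 0.0]))  # [1.0, -2.0], since sigmoid(0) == 0.5
```

This is why `fc1` projects to `2 * hidden_features` while `fc2` takes only `hidden_features` inputs.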
|
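`GatedMlp` and `ParallelGatedMlp` default the hidden width to `int(8 * in_features / 3)` and then round it up to a multiple of `multiple_of`. The following sketch isolates that arithmetic so the resulting sizes are easy to check. The function name `gated_hidden_features` is hypothetical; the expressions are copied from the constructors above.

```python
def gated_hidden_features(in_features, hidden_features=None, multiple_of=128):
    """Return the hidden width a gated MLP would use for the given settings."""
    if hidden_features is None:
        hidden_features = int(8 * in_features / 3)
    # Round up to the next multiple of `multiple_of`.
    return (hidden_features + multiple_of - 1) // multiple_of * multiple_of


print(gated_hidden_features(768))   # 2048 (already a multiple of 128)
print(gated_hidden_features(1024))  # 2816 (2730 rounded up)
print(gated_hidden_features(1024, hidden_features=3000, multiple_of=256))  # 3072
```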
model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0b416a3cd3ba537175d19829b09324707b9ee80bb603518dce2f3a2284feea76
|
| 3 |
+
size 556892306
|
modeling_xlm_roberta.py
ADDED
|
@@ -0,0 +1,1119 @@
|
| 1 |
+
# This implementation was adopted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/models/bert.py
|
| 2 |
+
# Commit id: abbc1311731867310635f9edc2a9ec18317c8c48
|
| 3 |
+
# Copyright (c) 2022, Tri Dao.
|
| 4 |
+
# This BERT implementation is based on our MLPerf 2.0 and MLPerf 2.1 BERT implementation.
|
| 5 |
+
# https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
|
| 6 |
+
# https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py
|
| 7 |
+
|
| 8 |
+
# Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
|
| 9 |
+
|
| 10 |
+
import importlib.util
|
| 11 |
+
import logging
|
| 12 |
+
import re
|
| 13 |
+
from collections import OrderedDict
|
| 14 |
+
from collections.abc import Sequence
|
| 15 |
+
from functools import partial
|
| 16 |
+
import numpy as np
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
import torch.nn.functional as F
|
| 21 |
+
import torch.utils.checkpoint
|
| 22 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 23 |
+
from einops import rearrange
|
| 24 |
+
from transformers import PretrainedConfig
|
| 25 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 26 |
+
from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput
|
| 27 |
+
from transformers.models.xlm_roberta.modeling_xlm_roberta import XLMRobertaLMHead
|
| 28 |
+
|
| 29 |
+
from transformers.models.bert.modeling_bert import (
|
| 30 |
+
BaseModelOutputWithPoolingAndCrossAttentions,
|
| 31 |
+
BertForPreTrainingOutput,
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
from typing import List, Optional, Tuple, Union
|
| 35 |
+
|
| 36 |
+
from .xlm_padding import (
|
| 37 |
+
index_first_axis,
|
| 38 |
+
index_first_axis_residual,
|
| 39 |
+
pad_input,
|
| 40 |
+
unpad_input,
|
| 41 |
+
)
|
| 42 |
+
from .configuration_xlm_roberta import XLMRobertaFlashConfig
|
| 43 |
+
from .block import Block
|
| 44 |
+
from .embedding import XLMRobertaEmbeddings
|
| 45 |
+
from .mha import MHA
|
| 46 |
+
from .mlp import FusedMLP, Mlp
|
| 47 |
+
|
| 48 |
+
try:
|
| 49 |
+
from flash_attn.ops.fused_dense import FusedDense
|
| 50 |
+
except ImportError:
|
| 51 |
+
FusedDense = None
|
| 52 |
+
|
| 53 |
+
try:
|
| 54 |
+
from flash_attn.ops.triton.layer_norm import layer_norm_fn
|
| 55 |
+
except ImportError:
|
| 56 |
+
layer_norm_fn = None
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
try:
|
| 60 |
+
from flash_attn.losses.cross_entropy import CrossEntropyLoss
|
| 61 |
+
except ImportError:
|
| 62 |
+
CrossEntropyLoss = torch.nn.CrossEntropyLoss
|
| 63 |
+
|
| 64 |
+
try:
|
| 65 |
+
from tqdm.autonotebook import trange
|
| 66 |
+
except ImportError:
|
| 67 |
+
trange = None
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
logger = logging.getLogger(__name__)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def get_use_flash_attn(config: XLMRobertaFlashConfig):
|
| 74 |
+
if not getattr(config, "use_flash_attn", False):
|
| 75 |
+
return False
|
| 76 |
+
if not torch.cuda.is_available():
|
| 77 |
+
return False
|
| 78 |
+
if importlib.util.find_spec("flash_attn") is None:
|
| 79 |
+
logger.warning(
|
| 80 |
+
'flash_attn is not installed. Using PyTorch native attention implementation.'
|
| 81 |
+
)
|
| 82 |
+
return False
|
| 83 |
+
return True
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def create_mixer_cls(config, cross_attn=False, return_residual=False):
|
| 87 |
+
use_flash_attn = get_use_flash_attn(config)
|
| 88 |
+
fused_bias_fc = getattr(config, "fused_bias_fc", False)
|
| 89 |
+
|
| 90 |
+
mixer_cls = partial(
|
| 91 |
+
MHA,
|
| 92 |
+
num_heads=config.num_attention_heads,
|
| 93 |
+
cross_attn=cross_attn,
|
| 94 |
+
dropout=config.attention_probs_dropout_prob,
|
| 95 |
+
causal=False,
|
| 96 |
+
fused_bias_fc=fused_bias_fc,
|
| 97 |
+
use_flash_attn=use_flash_attn,
|
| 98 |
+
return_residual=return_residual,
|
| 99 |
+
)
|
| 100 |
+
return mixer_cls
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def create_mlp_cls(config, layer_idx=None, return_residual=False):
|
| 104 |
+
inner_dim = config.intermediate_size
|
| 105 |
+
fused_mlp = getattr(config, "fused_mlp", False)
|
| 106 |
+
if fused_mlp:
|
| 107 |
+
assert config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"], (
|
| 108 |
+
"fused_mlp only supports approximate gelu"
|
| 109 |
+
)
|
| 110 |
+
if not fused_mlp:
|
| 111 |
+
approximate = (
|
| 112 |
+
"tanh"
|
| 113 |
+
if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
|
| 114 |
+
else "none"
|
| 115 |
+
)
|
| 116 |
+
mlp_cls = partial(
|
| 117 |
+
Mlp,
|
| 118 |
+
hidden_features=inner_dim,
|
| 119 |
+
activation=partial(F.gelu, approximate=approximate),
|
| 120 |
+
return_residual=return_residual,
|
| 121 |
+
)
|
| 122 |
+
else:
|
| 123 |
+
if FusedMLP is None:
|
| 124 |
+
raise ImportError("fused_dense is not installed")
|
| 125 |
+
mlp_checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0)
|
| 126 |
+
# mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
|
| 127 |
+
if isinstance(mlp_checkpoint_lvl, Sequence):
|
| 128 |
+
assert layer_idx is not None
|
| 129 |
+
mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
|
| 130 |
+
mlp_cls = partial(
|
| 131 |
+
FusedMLP,
|
| 132 |
+
hidden_features=inner_dim,
|
| 133 |
+
checkpoint_lvl=mlp_checkpoint_lvl,
|
| 134 |
+
return_residual=return_residual,
|
| 135 |
+
)
|
| 136 |
+
return mlp_cls
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def create_block(config, layer_idx=None):
|
| 140 |
+
last_layer_subset = getattr(config, "last_layer_subset", False)
|
| 141 |
+
cross_attn = last_layer_subset and layer_idx == config.num_hidden_layers - 1
|
| 142 |
+
# TD [2022-12-19]: For cross attention (last layer), we actually want to return the
|
| 143 |
+
# residual x_kv, not residual x. But it's annoying to change the API (and it only affects
|
| 144 |
+
# one layer) so we just choose not to return residual in this case.
|
| 145 |
+
return_residual = not cross_attn
|
| 146 |
+
mixer_cls = create_mixer_cls(config, cross_attn, return_residual=return_residual)
|
| 147 |
+
mlp_cls = create_mlp_cls(config, layer_idx, return_residual=return_residual)
|
| 148 |
+
norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_eps)
|
| 149 |
+
block = Block(
|
| 150 |
+
config.hidden_size,
|
| 151 |
+
mixer_cls,
|
| 152 |
+
mlp_cls,
|
| 153 |
+
norm_cls=norm_cls,
|
| 154 |
+
prenorm=False,
|
| 155 |
+
resid_dropout1=config.hidden_dropout_prob,
|
| 156 |
+
resid_dropout2=config.hidden_dropout_prob,
|
| 157 |
+
fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False),
|
| 158 |
+
return_residual=return_residual,
|
| 159 |
+
)
|
| 160 |
+
return block
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
# https://github.com/huggingface/transformers/blob/7032e0203262ebb2ebf55da8d2e01f873973e835/src/transformers/models/bert/modeling_bert.py#L748
|
| 164 |
+
def _init_weights(module, initializer_range=0.02):
|
| 165 |
+
if isinstance(module, nn.Linear):
|
| 166 |
+
nn.init.normal_(module.weight, std=initializer_range)
|
| 167 |
+
if module.bias is not None:
|
| 168 |
+
nn.init.zeros_(module.bias)
|
| 169 |
+
elif isinstance(module, nn.Embedding):
|
| 170 |
+
nn.init.normal_(module.weight, std=initializer_range)
|
| 171 |
+
if module.padding_idx is not None:
|
| 172 |
+
nn.init.zeros_(module.weight[module.padding_idx])
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
class XLMRobertaEncoder(nn.Module):
|
| 176 |
+
def __init__(self, config: XLMRobertaFlashConfig):
|
| 177 |
+
super().__init__()
|
| 178 |
+
self.use_flash_attn = get_use_flash_attn(config)
|
| 179 |
+
self.layers = nn.ModuleList(
|
| 180 |
+
[create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
|
| 181 |
+
)
|
| 182 |
+
self._grad_checkpointing = False
|
| 183 |
+
|
| 184 |
+
@property
|
| 185 |
+
def gradient_checkpointing(self):
|
| 186 |
+
return self._grad_checkpointing
|
| 187 |
+
|
| 188 |
+
@gradient_checkpointing.setter
|
| 189 |
+
def gradient_checkpointing(self, value):
|
| 190 |
+
self._grad_checkpointing = value
|
| 191 |
+
|
| 192 |
+
def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
|
| 193 |
+
"""If subset_mask is not None, we only want output for the subset of the sequence.
|
| 194 |
+
This means that we only compute the last layer output for these tokens.
|
| 195 |
+
subset_mask: (batch, seqlen), dtype=torch.bool
|
| 196 |
+
"""
|
| 197 |
+
if key_padding_mask is None or not self.use_flash_attn:
|
| 198 |
+
mixer_kwargs = (
|
| 199 |
+
{"key_padding_mask": key_padding_mask.bool()}
|
| 200 |
+
if key_padding_mask is not None
|
| 201 |
+
else None
|
| 202 |
+
)
|
| 203 |
+
for layer in self.layers:
|
| 204 |
+
if self._grad_checkpointing:
|
| 205 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
| 206 |
+
layer,
|
| 207 |
+
hidden_states,
|
| 208 |
+
use_reentrant=False,
|
| 209 |
+
mixer_kwargs=mixer_kwargs,
|
| 210 |
+
)
|
| 211 |
+
else:
|
| 212 |
+
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
|
| 213 |
+
if subset_mask is not None:
|
| 214 |
+
hidden_states = hidden_states[subset_mask]
|
| 215 |
+
else:
|
| 216 |
+
batch, seqlen = hidden_states.shape[:2]
|
| 217 |
+
hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
|
| 218 |
+
hidden_states, key_padding_mask
|
| 219 |
+
)
|
| 220 |
+
mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
|
| 221 |
+
if subset_mask is None:
|
| 222 |
+
for layer in self.layers:
|
| 223 |
+
if self._grad_checkpointing:
|
| 224 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
| 225 |
+
layer,
|
| 226 |
+
hidden_states,
|
| 227 |
+
use_reentrant=False,
|
| 228 |
+
mixer_kwargs=mixer_kwargs,
|
| 229 |
+
)
|
| 230 |
+
else:
|
| 231 |
+
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
|
| 232 |
+
hidden_states = pad_input(hidden_states, indices, batch, seqlen)
|
| 233 |
+
else:
|
| 234 |
+
for layer in self.layers[:-1]:
|
| 235 |
+
if self._grad_checkpointing:
|
| 236 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
| 237 |
+
layer,
|
| 238 |
+
hidden_states,
|
| 239 |
+
use_reentrant=False,
|
| 240 |
+
mixer_kwargs=mixer_kwargs,
|
| 241 |
+
)
|
| 242 |
+
else:
|
| 243 |
+
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
|
| 244 |
+
if key_padding_mask is not None:
|
| 245 |
+
subset_idx = torch.nonzero(
|
| 246 |
+
subset_mask[key_padding_mask], as_tuple=False
|
| 247 |
+
).flatten()
|
| 248 |
+
subset_seqlens = (subset_mask & key_padding_mask).sum(
|
| 249 |
+
dim=-1, dtype=torch.int32
|
| 250 |
+
)
|
| 251 |
+
subset_cu_seqlens = F.pad(
|
| 252 |
+
torch.cumsum(subset_seqlens, dim=0, dtype=torch.int32),
|
| 253 |
+
(1, 0),
|
| 254 |
+
)
|
| 255 |
+
else:
|
| 256 |
+
subset_idx = torch.nonzero(subset_mask, as_tuple=False).flatten()
|
| 257 |
+
subset_seqlens = subset_mask.sum(dim=-1, dtype=torch.int32)
|
| 258 |
+
subset_cu_seqlens = F.pad(
|
| 259 |
+
torch.cumsum(subset_seqlens, dim=0, dtype=torch.int32),
|
| 260 |
+
(1, 0),
|
| 261 |
+
)
|
| 262 |
+
hidden_states_subset, hidden_states = index_first_axis_residual(
|
| 263 |
+
hidden_states, subset_idx
|
| 264 |
+
)
|
| 265 |
+
# It's ok to set max_seqlen_q to be much larger
|
| 266 |
+
mixer_kwargs = {
|
| 267 |
+
"x_kv": hidden_states,
|
| 268 |
+
"cu_seqlens": subset_cu_seqlens,
|
| 269 |
+
"max_seqlen": max_seqlen_in_batch,
|
| 270 |
+
"cu_seqlens_k": cu_seqlens,
|
| 271 |
+
"max_seqlen_k": max_seqlen_in_batch,
|
| 272 |
+
}
|
| 273 |
+
if self._grad_checkpointing:
|
| 274 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
| 275 |
+
self.layers[-1],
|
| 276 |
+
hidden_states_subset,
|
| 277 |
+
use_reentrant=False,
|
| 278 |
+
mixer_kwargs=mixer_kwargs,
|
| 279 |
+
)
|
| 280 |
+
else:
|
| 281 |
+
hidden_states = self.layers[-1](
|
| 282 |
+
hidden_states_subset, mixer_kwargs=mixer_kwargs
|
| 283 |
+
)
|
| 284 |
+
return hidden_states
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
class XLMRobertaPooler(nn.Module):
|
| 288 |
+
def __init__(self, config):
|
| 289 |
+
super().__init__()
|
| 290 |
+
fused_bias_fc = getattr(config, "fused_bias_fc", False)
|
| 291 |
+
if fused_bias_fc and FusedDense is None:
|
| 292 |
+
raise ImportError("fused_dense is not installed")
|
| 293 |
+
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
|
| 294 |
+
self.dense = linear_cls(config.hidden_size, config.hidden_size)
|
| 295 |
+
self.activation = nn.Tanh()
|
| 296 |
+
|
| 297 |
+
def forward(self, hidden_states, pool=True):
|
| 298 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
| 299 |
+
# to the first token.
|
| 300 |
+
first_token_tensor = hidden_states[:, 0] if pool else hidden_states
|
| 301 |
+
pooled_output = self.dense(first_token_tensor)
|
| 302 |
+
pooled_output = self.activation(pooled_output)
|
| 303 |
+
return pooled_output
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
class XLMRobertaPredictionHeadTransform(nn.Module):
|
| 307 |
+
def __init__(self, config):
|
| 308 |
+
super().__init__()
|
| 309 |
+
fused_bias_fc = getattr(config, "fused_bias_fc", False)
|
| 310 |
+
if fused_bias_fc and FusedDense is None:
|
| 311 |
+
raise ImportError("fused_dense is not installed")
|
| 312 |
+
self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
|
| 313 |
+
if self.fused_dropout_add_ln and layer_norm_fn is None:
|
| 314 |
+
raise ImportError("Triton is not installed")
|
| 315 |
+
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
|
| 316 |
+
self.dense = linear_cls(config.hidden_size, config.hidden_size)
|
| 317 |
+
approximate = (
|
| 318 |
+
"tanh"
|
| 319 |
+
if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
|
| 320 |
+
else "none"
|
| 321 |
+
)
|
| 322 |
+
self.transform_act_fn = nn.GELU(approximate=approximate)
|
| 323 |
+
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 324 |
+
|
| 325 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 326 |
+
hidden_states = self.dense(hidden_states)
|
| 327 |
+
hidden_states = self.transform_act_fn(hidden_states)
|
| 328 |
+
if not self.fused_dropout_add_ln:
|
| 329 |
+
hidden_states = self.layer_norm(hidden_states)
|
| 330 |
+
else:
|
| 331 |
+
hidden_states = layer_norm_fn(
|
| 332 |
+
hidden_states,
|
| 333 |
+
self.layer_norm.weight,
|
| 334 |
+
self.layer_norm.bias,
|
| 335 |
+
eps=self.layer_norm.eps,
|
| 336 |
+
)
|
| 337 |
+
return hidden_states
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
class XLMRobertaLMPredictionHead(nn.Module):
|
| 341 |
+
def __init__(self, config):
|
| 342 |
+
super().__init__()
|
| 343 |
+
fused_bias_fc = getattr(config, "fused_bias_fc", False)
|
| 344 |
+
if fused_bias_fc and FusedDense is None:
|
| 345 |
+
raise ImportError("fused_dense is not installed")
|
| 346 |
+
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
|
| 347 |
+
|
| 348 |
+
self.transform = XLMRobertaPredictionHeadTransform(config)
|
| 349 |
+
|
| 350 |
+
# The output weights are the same as the input embeddings, but there is
|
| 351 |
+
# an output-only bias for each token.
|
| 352 |
+
self.decoder = linear_cls(config.hidden_size, config.vocab_size, bias=True)
|
| 353 |
+
|
| 354 |
+
def forward(self, hidden_states):
|
| 355 |
+
hidden_states = self.transform(hidden_states)
|
| 356 |
+
hidden_states = self.decoder(hidden_states)
|
| 357 |
+
return hidden_states
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
class XLMRobertaPreTrainingHeads(nn.Module):
|
| 361 |
+
def __init__(self, config):
|
| 362 |
+
super().__init__()
|
| 363 |
+
self.predictions = XLMRobertaLMPredictionHead(config)
|
| 364 |
+
self.seq_relationship = nn.Linear(config.hidden_size, 2)
|
| 365 |
+
|
| 366 |
+
def forward(self, sequence_output, pooled_output):
|
| 367 |
+
prediction_scores = self.predictions(sequence_output)
|
| 368 |
+
seq_relationship_score = self.seq_relationship(pooled_output)
|
| 369 |
+
return prediction_scores, seq_relationship_score
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
class XLMRobertaPreTrainedModel(PreTrainedModel):
|
| 373 |
+
"""An abstract class to handle weights initialization and
|
| 374 |
+
a simple interface for downloading and loading pretrained models.
|
| 375 |
+
"""
|
| 376 |
+
|
| 377 |
+
config_class = XLMRobertaFlashConfig
|
| 378 |
+
base_model_prefix = "roberta"
|
| 379 |
+
supports_gradient_checkpointing = True
|
| 380 |
+
|
| 381 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
| 382 |
+
if isinstance(module, XLMRobertaEncoder):
|
| 383 |
+
module.gradient_checkpointing = value
|
| 384 |
+
|
| 385 |
+
@classmethod
|
| 386 |
+
def from_pretrained(
|
| 387 |
+
cls,
|
| 388 |
+
*args,
|
| 389 |
+
**kwargs,
|
| 390 |
+
):
|
| 391 |
+
if 'torch_dtype' not in kwargs:
|
| 392 |
+
kwargs['torch_dtype'] = 'auto'
|
| 393 |
+
return super().from_pretrained(*args, **kwargs)
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
| 398 |
+
def __init__(self, config: XLMRobertaFlashConfig, add_pooling_layer=True):
|
| 399 |
+
super().__init__(config)
|
| 400 |
+
self.pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
|
| 401 |
+
if config.vocab_size % self.pad_vocab_size_multiple != 0:
|
| 402 |
+
config.vocab_size += self.pad_vocab_size_multiple - (
|
| 403 |
+
config.vocab_size % self.pad_vocab_size_multiple
|
| 404 |
+
)
|
| 405 |
+
self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
|
| 406 |
+
if self.fused_dropout_add_ln and layer_norm_fn is None:
|
| 407 |
+
raise ImportError("Triton is not installed")
|
| 408 |
+
assert config.hidden_act in [
|
| 409 |
+
"gelu",
|
| 410 |
+
"gelu_new",
|
| 411 |
+
"gelu_fast",
|
| 412 |
+
"gelu_pytorch_tanh",
|
| 413 |
+
]
|
| 414 |
+
|
| 415 |
+
self.embeddings = XLMRobertaEmbeddings(
|
| 416 |
+
config.hidden_size,
|
| 417 |
+
config.vocab_size,
|
| 418 |
+
config.max_position_embeddings if config.position_embedding_type == 'absolute' else -1,
|
| 419 |
+
config.type_vocab_size,
|
| 420 |
+
padding_idx=config.pad_token_id,
|
| 421 |
+
)
|
| 422 |
+
self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
|
| 423 |
+
self.emb_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 424 |
+
self.encoder = XLMRobertaEncoder(config)
|
| 425 |
+
self.pooler = XLMRobertaPooler(config) if add_pooling_layer else None
|
| 426 |
+
|
| 427 |
+
self.apply(partial(_init_weights, initializer_range=config.initializer_range))
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
@torch.inference_mode()
|
| 431 |
+
def encode(
|
| 432 |
+
self: 'XLMRobertaModel',
|
| 433 |
+
sentences: Union[str, List[str]],
|
| 434 |
+
batch_size: int = 32,
|
| 435 |
+
show_progress_bar: Optional[bool] = None,
|
| 436 |
+
output_value: str = 'sentence_embedding',
|
| 437 |
+
convert_to_numpy: bool = True,
|
| 438 |
+
convert_to_tensor: bool = False,
|
| 439 |
+
device: Optional[torch.device] = None,
|
| 440 |
+
normalize_embeddings: bool = False,
|
| 441 |
+
truncate_dim: Optional[int] = None,
|
| 442 |
+
**tokenizer_kwargs,
|
| 443 |
+
) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
|
| 444 |
+
"""
|
| 445 |
+
Computes sentence embeddings
|
| 446 |
+
Args:
|
| 447 |
+
sentences(`str` or `List[str]`):
|
| 448 |
+
Sentence or sentences to be encoded
|
| 449 |
+
batch_size(`int`, *optional*, defaults to 32):
|
| 450 |
+
Batch size for the computation
|
| 451 |
+
show_progress_bar(`bool`, *optional*, defaults to None):
|
| 452 |
+
Show a progress bar when encoding sentences.
|
| 453 |
+
If set to None, progress bar is only shown when
|
| 454 |
+
`logger.level == logging.INFO` or `logger.level == logging.DEBUG`.
|
| 455 |
+
output_value(`str`, *optional*, defaults to 'sentence_embedding'):
|
| 456 |
+
Default sentence_embedding, to get sentence embeddings.
|
| 457 |
+
`token_embeddings` (per-token embeddings) is not implemented here
|
| 458 |
+
and raises NotImplementedError, as does None (all output values).
|
| 459 |
+
convert_to_numpy(`bool`, *optional*, defaults to True):
|
| 460 |
+
If true, the output is a list of numpy vectors.
|
| 461 |
+
Else, it is a list of pytorch tensors.
|
| 462 |
+
convert_to_tensor(`bool`, *optional*, defaults to False):
|
| 463 |
+
If true, you get one large tensor as return.
|
| 464 |
+
Overwrites any setting from convert_to_numpy
|
| 465 |
+
device(`torch.device`, *optional*, defaults to None):
|
| 466 |
+
Which torch.device to use for the computation
|
| 467 |
+
normalize_embeddings(`bool`, *optional*, defaults to False):
|
| 468 |
+
If set to true, returned vectors will have length 1. In that case, the
|
| 469 |
+
faster dot-product (util.dot_score) instead of cosine similarity can
|
| 470 |
+
be used.
|
| 471 |
+
truncate_dim(`int`, *optional*, defaults to None):
|
| 472 |
+
The dimension to truncate sentence embeddings to. `None` does no truncation.
|
| 473 |
+
tokenizer_kwargs(`Dict[str, Any]`, *optional*, defaults to {}):
|
| 474 |
+
Keyword arguments for the tokenizer
|
| 475 |
+
Returns:
|
| 476 |
+
By default, a list of tensors is returned.
|
| 477 |
+
If convert_to_tensor, a stacked tensor is returned.
|
| 478 |
+
If convert_to_numpy, a numpy matrix is returned.
|
| 479 |
+
"""
|
| 480 |
+
from transformers import AutoTokenizer
|
| 481 |
+
|
| 482 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
| 483 |
+
self.name_or_path, trust_remote_code=True
|
| 484 |
+
)
|
| 485 |
+
|
| 486 |
+
is_training = self.training
|
| 487 |
+
self.eval()
|
| 488 |
+
|
| 489 |
+
if show_progress_bar is None:
|
| 490 |
+
show_progress_bar = (
|
| 491 |
+
logger.getEffectiveLevel() == logging.INFO
|
| 492 |
+
or logger.getEffectiveLevel() == logging.DEBUG
|
| 493 |
+
)
|
| 494 |
+
|
| 495 |
+
if convert_to_tensor:
|
| 496 |
+
convert_to_numpy = False
|
| 497 |
+
|
| 498 |
+
if output_value != 'sentence_embedding':
|
| 499 |
+
convert_to_tensor = False
|
| 500 |
+
convert_to_numpy = False
|
| 501 |
+
|
| 502 |
+
input_was_string = False
|
| 503 |
+
if isinstance(sentences, str) or not hasattr(sentences, '__len__'):
|
| 504 |
+
sentences = [sentences]
|
| 505 |
+
input_was_string = True
|
| 506 |
+
|
| 507 |
+
if device is not None:
|
| 508 |
+
self.to(device)
|
| 509 |
+
|
| 510 |
+
permutation = np.argsort([-len(i) for i in sentences])
|
| 511 |
+
inverse_permutation = np.argsort(permutation)
|
| 512 |
+
sentences = [sentences[idx] for idx in permutation]
|
| 513 |
+
|
| 514 |
+
tokenizer_kwargs['padding'] = tokenizer_kwargs.get('padding', True)
|
| 515 |
+
tokenizer_kwargs['max_length'] = tokenizer_kwargs.get(
|
| 516 |
+
'max_length', self.tokenizer.init_kwargs.get('model_max_length', 8192)
|
| 517 |
+
)
|
| 518 |
+
tokenizer_kwargs['truncation'] = tokenizer_kwargs.get('truncation', True)
|
| 519 |
+
|
| 520 |
+
all_embeddings = []
|
| 521 |
+
|
| 522 |
+
if trange is not None:
|
| 523 |
+
range_iter = trange(
|
| 524 |
+
0,
|
| 525 |
+
len(sentences),
|
| 526 |
+
batch_size,
|
| 527 |
+
desc="Encoding",
|
| 528 |
+
disable=not show_progress_bar,
|
| 529 |
+
)
|
| 530 |
+
else:
|
| 531 |
+
range_iter = range(0, len(sentences), batch_size)
|
| 532 |
+
|
| 533 |
+
for i in range_iter:
|
| 534 |
+
encoded_input = self.tokenizer(
|
| 535 |
+
sentences[i : i + batch_size],
|
| 536 |
+
return_tensors='pt',
|
| 537 |
+
**tokenizer_kwargs,
|
| 538 |
+
).to(self.device)
|
| 539 |
+
token_embs = self.forward(**encoded_input)[0]
|
| 540 |
+
|
| 541 |
+
# Accumulate in fp32 to avoid overflow
|
| 542 |
+
token_embs = token_embs.float()
|
| 543 |
+
|
| 544 |
+
if output_value == 'token_embeddings':
|
| 545 |
+
raise NotImplementedError
|
| 546 |
+
elif output_value is None:
|
| 547 |
+
raise NotImplementedError
|
| 548 |
+
else:
|
| 549 |
+
if self.config.emb_pooler == 'cls':
|
| 550 |
+
embeddings = self.cls_pooling(
|
| 551 |
+
token_embs, encoded_input['attention_mask']
|
| 552 |
+
)
|
| 553 |
+
else:
|
| 554 |
+
embeddings = self.mean_pooling(
|
| 555 |
+
token_embs, encoded_input['attention_mask']
|
| 556 |
+
)
|
| 557 |
+
|
| 558 |
+
if normalize_embeddings:
|
| 559 |
+
embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
|
| 560 |
+
|
| 561 |
+
if convert_to_numpy:
|
| 562 |
+
embeddings = embeddings.cpu()
|
| 563 |
+
all_embeddings.extend(embeddings)
|
| 564 |
+
|
| 565 |
+
all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
|
| 566 |
+
|
| 567 |
+
truncate_dim = truncate_dim or self.config.truncate_dim
|
| 568 |
+
if truncate_dim:
|
| 569 |
+
all_embeddings = self.truncate_embeddings(all_embeddings, truncate_dim)
|
| 570 |
+
|
| 571 |
+
if convert_to_tensor:
|
| 572 |
+
all_embeddings = torch.stack(all_embeddings)
|
| 573 |
+
elif convert_to_numpy:
|
| 574 |
+
all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
|
| 575 |
+
|
| 576 |
+
if input_was_string:
|
| 577 |
+
all_embeddings = all_embeddings[0]
|
| 578 |
+
|
| 579 |
+
self.train(is_training)
|
| 580 |
+
return all_embeddings
|
| 581 |
+
|
| 582 |
+
|
| 583 |
+
def truncate_embeddings(self, embeddings, truncate_dim):
|
| 584 |
+
if not self.config.matryoshka_dimensions:
|
| 585 |
+
logger.warning(
|
| 586 |
+
'Matryoshka embeddings are not supported, so dimension truncation will not be performed.'
|
| 587 |
+
)
|
| 588 |
+
return embeddings
|
| 589 |
+
elif truncate_dim in self.config.matryoshka_dimensions:
|
| 590 |
+
return [tensor[:truncate_dim] for tensor in embeddings]
|
| 591 |
+
else:
|
| 592 |
+
raise ValueError(f'The provided `truncate_dim` value of {truncate_dim} is not supported. '
|
| 593 |
+
f'Supported dimensions are {self.config.matryoshka_dimensions}.')
|
| 594 |
+
|
| 595 |
+
def mean_pooling(
|
| 596 |
+
self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
|
| 597 |
+
):
|
| 598 |
+
input_mask_expanded = (
|
| 599 |
+
attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
|
| 600 |
+
)
|
| 601 |
+
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
|
| 602 |
+
input_mask_expanded.sum(1), min=1e-9
|
| 603 |
+
)
|
| 604 |
+
|
| 605 |
+
|
| 606 |
+
def cls_pooling(
|
| 607 |
+
self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
|
| 608 |
+
):
|
| 609 |
+
return token_embeddings[:,0]
|
| 610 |
+
|
| 611 |
+
|
| 612 |
+
def forward(
|
| 613 |
+
self,
|
| 614 |
+
input_ids,
|
| 615 |
+
position_ids=None,
|
| 616 |
+
token_type_ids=None,
|
| 617 |
+
attention_mask=None,
|
| 618 |
+
masked_tokens_mask=None,
|
| 619 |
+
return_dict=None,
|
| 620 |
+
**kwargs,
|
| 621 |
+
):
|
| 622 |
+
"""If masked_tokens_mask is not None (i.e. last_layer_subset == True in XLMForPreTraining),
|
| 623 |
+
we only want the output for the masked tokens. This means that we only compute the last
|
| 624 |
+
layer output for these tokens.
|
| 625 |
+
masked_tokens_mask: (batch, seqlen), dtype=torch.bool
|
| 626 |
+
"""
|
| 627 |
+
|
| 628 |
+
if kwargs:
|
| 629 |
+
for key, value in kwargs.items():
|
| 630 |
+
if value is not None:
|
| 631 |
+
logger.warning(
|
| 632 |
+
'Flash attention implementation does not support kwargs: %s',
|
| 633 |
+
key,
|
| 634 |
+
)
|
| 635 |
+
|
| 636 |
+
return_dict = (
|
| 637 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
| 638 |
+
)
|
| 639 |
+
|
| 640 |
+
hidden_states = self.embeddings(
|
| 641 |
+
input_ids, position_ids=position_ids, token_type_ids=token_type_ids
|
| 642 |
+
)
|
| 643 |
+
# TD [2022-12:18]: Don't need to force residual in fp32
|
| 644 |
+
# BERT puts embedding LayerNorm before embedding dropout.
|
| 645 |
+
if not self.fused_dropout_add_ln:
|
| 646 |
+
hidden_states = self.emb_ln(hidden_states)
|
| 647 |
+
else:
|
| 648 |
+
hidden_states = layer_norm_fn(
|
| 649 |
+
hidden_states, self.emb_ln.weight, self.emb_ln.bias, eps=self.emb_ln.eps
|
| 650 |
+
)
|
| 651 |
+
hidden_states = self.emb_drop(hidden_states)
|
| 652 |
+
|
| 653 |
+
if masked_tokens_mask is not None:
|
| 654 |
+
batch_size, seqlen = input_ids.shape[:2]
|
| 655 |
+
# We also need the first column for the CLS token
|
| 656 |
+
first_col_mask = torch.zeros(
|
| 657 |
+
batch_size, seqlen, dtype=torch.bool, device=input_ids.device
|
| 658 |
+
)
|
| 659 |
+
first_col_mask[:, 0] = True
|
| 660 |
+
subset_mask = masked_tokens_mask | first_col_mask
|
| 661 |
+
else:
|
| 662 |
+
subset_mask = None
|
| 663 |
+
|
| 664 |
+
sequence_output = self.encoder(
|
| 665 |
+
hidden_states, key_padding_mask=attention_mask, subset_mask=subset_mask
|
| 666 |
+
)
|
| 667 |
+
|
| 668 |
+
if masked_tokens_mask is None:
|
| 669 |
+
pooled_output = (
|
| 670 |
+
self.pooler(sequence_output) if self.pooler is not None else None
|
| 671 |
+
)
|
| 672 |
+
else:
|
| 673 |
+
# TD [2022-03-01]: the indexing here is very tricky.
|
| 674 |
+
if attention_mask is not None:
|
| 675 |
+
subset_idx = subset_mask[attention_mask]
|
| 676 |
+
pool_input = sequence_output[first_col_mask[attention_mask][subset_idx]]
|
| 677 |
+
sequence_output = sequence_output[
|
| 678 |
+
masked_tokens_mask[attention_mask][subset_idx]
|
| 679 |
+
]
|
| 680 |
+
else:
|
| 681 |
+
pool_input = sequence_output[first_col_mask[subset_mask]]
|
| 682 |
+
sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
|
| 683 |
+
pooled_output = (
|
| 684 |
+
self.pooler(pool_input, pool=False) if self.pooler is not None else None
|
| 685 |
+
)
|
| 686 |
+
|
| 687 |
+
if not return_dict:
|
| 688 |
+
return sequence_output, pooled_output
|
| 689 |
+
|
| 690 |
+
return BaseModelOutputWithPoolingAndCrossAttentions(
|
| 691 |
+
last_hidden_state=sequence_output,
|
| 692 |
+
pooler_output=pooled_output,
|
| 693 |
+
)
|
| 694 |
+
|
| 695 |
+
|
| 696 |
+
class XLMRobertaForMaskedLM(XLMRobertaPreTrainedModel):
|
| 697 |
+
_tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
|
| 698 |
+
|
| 699 |
+
def __init__(self, config):
|
| 700 |
+
super().__init__(config)
|
| 701 |
+
|
| 702 |
+
if config.is_decoder:
|
| 703 |
+
logger.warning(
|
| 704 |
+
"If you want to use `XLMRobertaForMaskedLM` make sure `config.is_decoder=False` for "
|
| 705 |
+
"bi-directional self-attention."
|
| 706 |
+
)
|
| 707 |
+
|
| 708 |
+
self.roberta = XLMRobertaModel(config, add_pooling_layer=False)
|
| 709 |
+
self.lm_head = XLMRobertaLMHead(config)
|
| 710 |
+
|
| 711 |
+
# Initialize weights and apply final processing
|
| 712 |
+
self.post_init()
|
| 713 |
+
|
| 714 |
+
def get_input_embeddings(self):
|
| 715 |
+
return self.roberta.embeddings.word_embeddings
|
| 716 |
+
|
| 717 |
+
def get_output_embeddings(self):
|
| 718 |
+
return self.lm_head.decoder
|
| 719 |
+
|
| 720 |
+
def set_output_embeddings(self, new_embeddings):
|
| 721 |
+
self.lm_head.decoder = new_embeddings
|
| 722 |
+
|
| 723 |
+
def forward(
|
| 724 |
+
self,
|
| 725 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 726 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 727 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
| 728 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 729 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 730 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 731 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
| 732 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
| 733 |
+
labels: Optional[torch.LongTensor] = None,
|
| 734 |
+
output_attentions: Optional[bool] = None,
|
| 735 |
+
output_hidden_states: Optional[bool] = None,
|
| 736 |
+
return_dict: Optional[bool] = None,
|
| 737 |
+
) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
|
| 738 |
+
r"""
|
| 739 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 740 |
+
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
|
| 741 |
+
config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
|
| 742 |
+
loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
|
| 743 |
+
kwargs (`Dict[str, any]`, optional, defaults to *{}*):
|
| 744 |
+
Used to hide legacy arguments that have been deprecated.
|
| 745 |
+
"""
|
| 746 |
+
return_dict = (
|
| 747 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
| 748 |
+
)
|
| 749 |
+
|
| 750 |
+
outputs = self.roberta(
|
| 751 |
+
input_ids,
|
| 752 |
+
attention_mask=attention_mask,
|
| 753 |
+
token_type_ids=token_type_ids,
|
| 754 |
+
position_ids=position_ids,
|
| 755 |
+
head_mask=head_mask,
|
| 756 |
+
inputs_embeds=inputs_embeds,
|
| 757 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 758 |
+
encoder_attention_mask=encoder_attention_mask,
|
| 759 |
+
output_attentions=output_attentions,
|
| 760 |
+
output_hidden_states=output_hidden_states,
|
| 761 |
+
return_dict=return_dict,
|
| 762 |
+
)
|
| 763 |
+
sequence_output = outputs[0]
|
| 764 |
+
prediction_scores = self.lm_head(sequence_output)
|
| 765 |
+
|
| 766 |
+
masked_lm_loss = None
|
| 767 |
+
if labels is not None:
|
| 768 |
+
# move labels to correct device to enable model parallelism
|
| 769 |
+
labels = labels.to(prediction_scores.device)
|
| 770 |
+
loss_fct = CrossEntropyLoss()
|
| 771 |
+
masked_lm_loss = loss_fct(
|
| 772 |
+
prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
|
| 773 |
+
)
|
| 774 |
+
|
| 775 |
+
if not return_dict:
|
| 776 |
+
output = (prediction_scores,) + outputs[2:]
|
| 777 |
+
return (
|
| 778 |
+
((masked_lm_loss,) + output) if masked_lm_loss is not None else output
|
| 779 |
+
)
|
| 780 |
+
|
| 781 |
+
return MaskedLMOutput(
|
| 782 |
+
loss=masked_lm_loss,
|
| 783 |
+
logits=prediction_scores,
|
| 784 |
+
hidden_states=outputs.hidden_states,
|
| 785 |
+
attentions=outputs.attentions,
|
| 786 |
+
)
|
| 787 |
+
|
| 788 |
+
|
| 789 |
+
# Copied from transformers.models.roberta.modeling_roberta.RobertaClassificationHead with Roberta->XLMRoberta
|
| 790 |
+
class XLMRobertaClassificationHead(nn.Module):
|
| 791 |
+
"""Head for sentence-level classification tasks."""
|
| 792 |
+
|
| 793 |
+
def __init__(self, config):
|
| 794 |
+
super().__init__()
|
| 795 |
+
fused_bias_fc = getattr(config, "fused_bias_fc", False)
|
| 796 |
+
if fused_bias_fc and FusedDense is None:
|
| 797 |
+
raise ImportError("fused_dense is not installed")
|
| 798 |
+
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
|
| 799 |
+
self.dense = linear_cls(config.hidden_size, config.hidden_size)
|
| 800 |
+
classifier_dropout = (
|
| 801 |
+
config.classifier_dropout
|
| 802 |
+
if config.classifier_dropout is not None
|
| 803 |
+
else config.hidden_dropout_prob
|
| 804 |
+
)
|
| 805 |
+
self.dropout = nn.Dropout(classifier_dropout)
|
| 806 |
+
self.out_proj = linear_cls(config.hidden_size, config.num_labels)
|
| 807 |
+
|
| 808 |
+
def forward(self, features, **kwargs):
|
| 809 |
+
x = features[:, 0, :] # take <s> token (equiv. to [CLS])
|
| 810 |
+
x = self.dropout(x)
|
| 811 |
+
x = self.dense(x)
|
| 812 |
+
x = torch.tanh(x)
|
| 813 |
+
x = self.dropout(x)
|
| 814 |
+
x = self.out_proj(x)
|
| 815 |
+
return x
|
| 816 |
+
|
| 817 |
+
|
| 818 |
+
# Copied from transformers.models.roberta.modeling_roberta.RobertaForSequenceClassification with Roberta->XLMRoberta, ROBERTA->XLM_ROBERTA
|
| 819 |
+
class XLMRobertaForSequenceClassification(XLMRobertaPreTrainedModel):
|
| 820 |
+
def __init__(self, config):
|
| 821 |
+
super().__init__(config)
|
| 822 |
+
self.num_labels = config.num_labels
|
| 823 |
+
self.config = config
|
| 824 |
+
|
| 825 |
+
self.roberta = XLMRobertaModel(config, add_pooling_layer=False)
|
| 826 |
+
self.classifier = XLMRobertaClassificationHead(config)
|
| 827 |
+
|
| 828 |
+
# Initialize weights and apply final processing
|
| 829 |
+
self.post_init()
|
| 830 |
+
|
| 831 |
+
def forward(
|
| 832 |
+
self,
|
| 833 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 834 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 835 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
| 836 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 837 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 838 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 839 |
+
labels: Optional[torch.LongTensor] = None,
|
| 840 |
+
output_attentions: Optional[bool] = None,
|
| 841 |
+
output_hidden_states: Optional[bool] = None,
|
| 842 |
+
return_dict: Optional[bool] = None,
|
| 843 |
+
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
|
| 844 |
+
r"""
|
| 845 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 846 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
| 847 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
|
| 848 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 849 |
+
"""
|
| 850 |
+
return_dict = (
|
| 851 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
| 852 |
+
)
|
| 853 |
+
|
| 854 |
+
outputs = self.roberta(
|
| 855 |
+
input_ids,
|
| 856 |
+
attention_mask=attention_mask,
|
| 857 |
+
token_type_ids=token_type_ids,
|
| 858 |
+
position_ids=position_ids,
|
| 859 |
+
head_mask=head_mask,
|
| 860 |
+
inputs_embeds=inputs_embeds,
|
| 861 |
+
output_attentions=output_attentions,
|
| 862 |
+
output_hidden_states=output_hidden_states,
|
| 863 |
+
return_dict=return_dict,
|
| 864 |
+
)
|
| 865 |
+
sequence_output = outputs[0]
|
| 866 |
+
logits = self.classifier(sequence_output)
|
| 867 |
+
|
| 868 |
+
loss = None
|
| 869 |
+
if labels is not None:
|
| 870 |
+
# move labels to correct device to enable model parallelism
|
| 871 |
+
labels = labels.to(logits.device)
|
| 872 |
+
if self.config.problem_type is None:
|
| 873 |
+
if self.num_labels == 1:
|
| 874 |
+
self.config.problem_type = "regression"
|
| 875 |
+
elif self.num_labels > 1 and (
|
| 876 |
+
labels.dtype == torch.long or labels.dtype == torch.int
|
| 877 |
+
):
|
| 878 |
+
self.config.problem_type = "single_label_classification"
|
| 879 |
+
else:
|
| 880 |
+
self.config.problem_type = "multi_label_classification"
|
| 881 |
+
|
| 882 |
+
if self.config.problem_type == "regression":
|
| 883 |
+
loss_fct = MSELoss()
|
| 884 |
+
if self.num_labels == 1:
|
| 885 |
+
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
| 886 |
+
else:
|
| 887 |
+
loss = loss_fct(logits, labels)
|
| 888 |
+
elif self.config.problem_type == "single_label_classification":
|
| 889 |
+
loss_fct = CrossEntropyLoss()
|
| 890 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
| 891 |
+
elif self.config.problem_type == "multi_label_classification":
|
| 892 |
+
loss_fct = BCEWithLogitsLoss()
|
| 893 |
+
loss = loss_fct(logits, labels)
|
| 894 |
+
|
| 895 |
+
if not return_dict:
|
| 896 |
+
output = (logits,) + outputs[2:]
|
| 897 |
+
return ((loss,) + output) if loss is not None else output
|
| 898 |
+
|
| 899 |
+
return SequenceClassifierOutput(
|
| 900 |
+
loss=loss,
|
| 901 |
+
logits=logits,
|
| 902 |
+
hidden_states=outputs.hidden_states,
|
| 903 |
+
attentions=outputs.attentions,
|
| 904 |
+
)
|
| 905 |
+
|
| 906 |
+
|
| 907 |
+
@torch.inference_mode()
|
| 908 |
+
def compute_score(
|
| 909 |
+
self,
|
| 910 |
+
sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]],
|
| 911 |
+
batch_size: int = 32,
|
| 912 |
+
max_length: Optional[int] = None,
|
| 913 |
+
) -> List[float]:
|
| 914 |
+
|
| 915 |
+
if not hasattr(self, "_tokenizer"):
|
| 916 |
+
from transformers import AutoTokenizer
|
| 917 |
+
|
| 918 |
+
self._tokenizer = AutoTokenizer.from_pretrained(
|
| 919 |
+
self.name_or_path, trust_remote_code=True
|
| 920 |
+
)
|
| 921 |
+
|
| 922 |
+
assert isinstance(sentence_pairs, list)
|
| 923 |
+
if isinstance(sentence_pairs[0], str):
|
| 924 |
+
sentence_pairs = [sentence_pairs]
|
| 925 |
+
|
| 926 |
+
all_scores = []
|
| 927 |
+
for start_index in range(
|
| 928 |
+
0, len(sentence_pairs), batch_size
|
| 929 |
+
):
|
| 930 |
+
sentences_batch = sentence_pairs[
|
| 931 |
+
start_index : start_index + batch_size
|
| 932 |
+
]
|
| 933 |
+
inputs = self._tokenizer(
|
| 934 |
+
sentences_batch,
|
| 935 |
+
padding=True,
|
| 936 |
+
truncation=True,
|
| 937 |
+
return_tensors='pt',
|
| 938 |
+
max_length=max_length,
|
| 939 |
+
).to(self.device)
|
| 940 |
+
scores = (
|
| 941 |
+
self.forward(**inputs, return_dict=True)
|
| 942 |
+
.logits.view(
|
| 943 |
+
-1,
|
| 944 |
+
)
|
| 945 |
+
.float()
|
| 946 |
+
)
|
| 947 |
+
scores = torch.sigmoid(scores)
|
| 948 |
+
all_scores.extend(scores.cpu().numpy().tolist())
|
| 949 |
+
|
| 950 |
+
if len(all_scores) == 1:
|
| 951 |
+
return all_scores[0]
|
| 952 |
+
return all_scores
|
| 953 |
+
|
| 954 |
+
def predict(
|
| 955 |
+
self,
|
| 956 |
+
sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]],
|
| 957 |
+
batch_size: int = 32,
|
| 958 |
+
max_length: Optional[int] = None,
|
| 959 |
+
) -> List[float]:
|
| 960 |
+
# used for beir evaluation
|
| 961 |
+
return self.compute_score(sentence_pairs, batch_size=batch_size, max_length=max_length)
|
| 962 |
+
|
| 963 |
+
def rerank(
|
| 964 |
+
self,
|
| 965 |
+
query: str,
|
| 966 |
+
documents: List[str],
|
| 967 |
+
batch_size: int = 32,
|
| 968 |
+
max_length: int = 1024,
|
| 969 |
+
max_query_length: int = 512,
|
| 970 |
+
overlap_tokens: int = 80,
|
| 971 |
+
top_n: Optional[int] = None,
|
| 972 |
+
**kwargs,
|
| 973 |
+
):
|
| 974 |
+
assert max_length >= max_query_length * 2, (
|
| 975 |
+
f'max_length ({max_length}) must be greater than or equal to '
|
| 976 |
+
f'max_query_length ({max_query_length}) * 2'
|
| 977 |
+
)
|
| 978 |
+
|
| 979 |
+
if not hasattr(self, "_tokenizer"):
|
| 980 |
+
from transformers import AutoTokenizer
|
| 981 |
+
|
| 982 |
+
self._tokenizer = AutoTokenizer.from_pretrained(
|
| 983 |
+
self.name_or_path, trust_remote_code=True
|
| 984 |
+
)
|
| 985 |
+
|
| 986 |
+
# preproc of tokenization
|
| 987 |
+
sentence_pairs, sentence_pairs_pids = reranker_tokenize_preproc(
|
| 988 |
+
query,
|
| 989 |
+
documents,
|
| 990 |
+
tokenizer=self._tokenizer,
|
| 991 |
+
max_length=max_length,
|
| 992 |
+
max_query_length=max_query_length,
|
| 993 |
+
overlap_tokens=overlap_tokens,
|
| 994 |
+
)
|
| 995 |
+
|
| 996 |
+
tot_scores = []
|
| 997 |
+
with torch.no_grad():
|
| 998 |
+
for k in range(0, len(sentence_pairs), batch_size):
|
| 999 |
+
batch = self._tokenizer.pad(
|
| 1000 |
+
sentence_pairs[k : k + batch_size],
|
| 1001 |
+
padding=True,
|
| 1002 |
+
max_length=max_length,
|
| 1003 |
+
pad_to_multiple_of=None,
|
| 1004 |
+
return_tensors="pt",
|
| 1005 |
+
)
|
| 1006 |
+
batch_on_device = {k: v.to(self.device) for k, v in batch.items()}
|
| 1007 |
+
scores = (
|
| 1008 |
+
self.forward(**batch_on_device, return_dict=True)
|
| 1009 |
+
.logits.view(
|
| 1010 |
+
-1,
|
| 1011 |
+
)
|
| 1012 |
+
.float()
|
| 1013 |
+
)
|
| 1014 |
+
scores = torch.sigmoid(scores)
|
| 1015 |
+
tot_scores.extend(scores.cpu().numpy().tolist())
|
| 1016 |
+
|
| 1017 |
+
# ranking
|
| 1018 |
+
merge_scores = [0 for _ in range(len(documents))]
|
| 1019 |
+
for pid, score in zip(sentence_pairs_pids, tot_scores):
|
| 1020 |
+
merge_scores[pid] = max(merge_scores[pid], score)
|
| 1021 |
+
|
| 1022 |
+
merge_scores_argsort = np.argsort(merge_scores)[::-1]
|
| 1023 |
+
sorted_documents = []
|
| 1024 |
+
sorted_scores = []
|
| 1025 |
+
for mid in merge_scores_argsort:
|
| 1026 |
+
sorted_scores.append(merge_scores[mid])
|
| 1027 |
+
sorted_documents.append(documents[mid])
|
| 1028 |
+
|
| 1029 |
+
top_n = min(top_n or len(sorted_documents), len(sorted_documents))
|
| 1030 |
+
|
| 1031 |
+
return [
|
| 1032 |
+
{
|
| 1033 |
+
'document': sorted_documents[i],
|
| 1034 |
+
'relevance_score': sorted_scores[i],
|
| 1035 |
+
'index': merge_scores_argsort[i],
|
| 1036 |
+
}
|
| 1037 |
+
for i in range(top_n)
|
| 1038 |
+
]
|
| 1039 |
+
|
| 1040 |
+
|
| 1041 |
+
def reranker_tokenize_preproc(
|
| 1042 |
+
query: str,
|
| 1043 |
+
passages: List[str],
|
| 1044 |
+
tokenizer=None,
|
| 1045 |
+
max_length: int = 1024,
|
| 1046 |
+
max_query_length: int = 512,
|
| 1047 |
+
overlap_tokens: int = 80,
|
| 1048 |
+
):
|
| 1049 |
+
from copy import deepcopy
|
| 1050 |
+
|
| 1051 |
+
assert tokenizer is not None, "Please provide a valid tokenizer for tokenization!"
|
| 1052 |
+
sep_id = tokenizer.sep_token_id
|
| 1053 |
+
|
| 1054 |
+
def _merge_inputs(chunk1_raw, chunk2):
|
| 1055 |
+
chunk1 = deepcopy(chunk1_raw)
|
| 1056 |
+
chunk1['input_ids'].append(sep_id)
|
| 1057 |
+
chunk1['input_ids'].extend(chunk2['input_ids'])
|
| 1058 |
+
chunk1['input_ids'].append(sep_id)
|
| 1059 |
+
chunk1['attention_mask'].append(chunk2['attention_mask'][0])
|
| 1060 |
+
chunk1['attention_mask'].extend(chunk2['attention_mask'])
|
| 1061 |
+
chunk1['attention_mask'].append(chunk2['attention_mask'][-1])
|
| 1062 |
+
if 'token_type_ids' in chunk1:
|
| 1063 |
+
token_type_ids = [1 for _ in range(len(chunk2['token_type_ids']) + 2)]
|
| 1064 |
+
chunk1['token_type_ids'].extend(token_type_ids)
|
| 1065 |
+
return chunk1
|
| 1066 |
+
|
| 1067 |
+
# Note: a long query is truncated to `max_query_length` tokens (512 by default)
|
| 1068 |
+
query_inputs = tokenizer.encode_plus(
|
| 1069 |
+
query, truncation=True, padding=False, max_length=max_query_length
|
| 1070 |
+
)
|
| 1071 |
+
|
| 1072 |
+
max_passage_inputs_length = max_length - len(query_inputs['input_ids']) - 2
|
| 1073 |
+
# assert (
|
| 1074 |
+
# max_passage_inputs_length > 100
|
| 1075 |
+
# ), "Your query is too long! Please make sure your query less than 500 tokens!"
|
| 1076 |
+
|
| 1077 |
+
overlap_tokens_implt = min(overlap_tokens, max_passage_inputs_length // 4)
|
| 1078 |
+
|
| 1079 |
+
res_merge_inputs = []
|
| 1080 |
+
res_merge_inputs_pids = []
|
| 1081 |
+
for pid, passage in enumerate(passages):
|
| 1082 |
+
passage_inputs = tokenizer.encode_plus(
|
| 1083 |
+
passage,
|
| 1084 |
+
truncation=False,
|
| 1085 |
+
padding=False,
|
| 1086 |
+
add_special_tokens=False,
|
| 1087 |
+
max_length=0,
|
| 1088 |
+
)
|
| 1089 |
+
passage_inputs_length = len(passage_inputs['input_ids'])
|
| 1090 |
+
|
| 1091 |
+
if passage_inputs_length <= max_passage_inputs_length:
|
| 1092 |
+
qp_merge_inputs = _merge_inputs(query_inputs, passage_inputs)
|
| 1093 |
+
res_merge_inputs.append(qp_merge_inputs)
|
| 1094 |
+
res_merge_inputs_pids.append(pid)
|
| 1095 |
+
else:
|
| 1096 |
+
start_id = 0
|
| 1097 |
+
while start_id < passage_inputs_length:
|
| 1098 |
+
end_id = start_id + max_passage_inputs_length
|
| 1099 |
+
# make sure the length of the last chunk is `max_passage_inputs_length`
|
| 1100 |
+
if end_id >= passage_inputs_length:
|
| 1101 |
+
sub_passage_inputs = {
|
| 1102 |
+
k: v[-max_passage_inputs_length:]
|
| 1103 |
+
for k, v in passage_inputs.items()
|
| 1104 |
+
}
|
| 1105 |
+
else:
|
| 1106 |
+
sub_passage_inputs = {
|
| 1107 |
+
k: v[start_id:end_id] for k, v in passage_inputs.items()
|
| 1108 |
+
}
|
| 1109 |
+
start_id = (
|
| 1110 |
+
end_id - overlap_tokens_implt
|
| 1111 |
+
if end_id < passage_inputs_length
|
| 1112 |
+
else end_id
|
| 1113 |
+
)
|
| 1114 |
+
|
| 1115 |
+
qp_merge_inputs = _merge_inputs(query_inputs, sub_passage_inputs)
|
| 1116 |
+
res_merge_inputs.append(qp_merge_inputs)
|
| 1117 |
+
res_merge_inputs_pids.append(pid)
|
| 1118 |
+
|
| 1119 |
+
return res_merge_inputs, res_merge_inputs_pids
|
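Reviewer note: the sliding-window chunking in `reranker_tokenize_preproc` and the per-document max-merge in `rerank` can be illustrated without a tokenizer or model. The following is a minimal sketch, not the code from this commit. The helper names `chunk_spans` and `merge_max` are hypothetical, and the sketch assumes `0 <= overlap < window` so that the window always advances.

```python
from typing import List, Sequence, Tuple


def chunk_spans(n_tokens: int, window: int, overlap: int) -> List[Tuple[int, int]]:
    """Return (start, end) token spans for one passage (hypothetical helper)."""
    # Assumption of this sketch: the window must advance on every step.
    assert 0 <= overlap < window, "sketch assumes 0 <= overlap < window"
    if n_tokens <= window:
        # Short passage: a single chunk covers everything.
        return [(0, n_tokens)]
    spans = []
    start = 0
    while start < n_tokens:
        end = start + window
        if end >= n_tokens:
            # Last chunk: take the final `window` tokens so it is full length.
            spans.append((n_tokens - window, n_tokens))
            start = end
        else:
            spans.append((start, end))
            # Step back by `overlap` so neighbouring chunks share context.
            start = end - overlap
    return spans


def merge_max(pids: Sequence[int], scores: Sequence[float], n_docs: int) -> List[float]:
    """Keep the best chunk score per document (hypothetical helper)."""
    merged = [0.0] * n_docs
    for pid, score in zip(pids, scores):
        merged[pid] = max(merged[pid], score)
    return merged


spans = chunk_spans(10, window=4, overlap=1)
print(spans)
print(merge_max([0, 0, 1], [0.2, 0.7, 0.4], n_docs=2))
```

Starting each merged score at 0 is safe here because the scores being merged are sigmoid outputs, which are always positive.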
special_tokens_map.json
ADDED
|
@@ -0,0 +1,51 @@
|
| 1 |
+
{
|
| 2 |
+
"bos_token": {
|
| 3 |
+
"content": "<s>",
|
| 4 |
+
"lstrip": false,
|
| 5 |
+
"normalized": false,
|
| 6 |
+
"rstrip": false,
|
| 7 |
+
"single_word": false
|
| 8 |
+
},
|
| 9 |
+
"cls_token": {
|
| 10 |
+
"content": "<s>",
|
| 11 |
+
"lstrip": false,
|
| 12 |
+
"normalized": false,
|
| 13 |
+
"rstrip": false,
|
| 14 |
+
"single_word": false
|
| 15 |
+
},
|
| 16 |
+
"eos_token": {
|
| 17 |
+
"content": "</s>",
|
| 18 |
+
"lstrip": false,
|
| 19 |
+
"normalized": false,
|
| 20 |
+
"rstrip": false,
|
| 21 |
+
"single_word": false
|
| 22 |
+
},
|
| 23 |
+
"mask_token": {
|
| 24 |
+
"content": "<mask>",
|
| 25 |
+
"lstrip": true,
|
| 26 |
+
"normalized": false,
|
| 27 |
+
"rstrip": false,
|
| 28 |
+
"single_word": false
|
| 29 |
+
},
|
| 30 |
+
"pad_token": {
|
| 31 |
+
"content": "<pad>",
|
| 32 |
+
"lstrip": false,
|
| 33 |
+
"normalized": false,
|
| 34 |
+
"rstrip": false,
|
| 35 |
+
"single_word": false
|
| 36 |
+
},
|
| 37 |
+
"sep_token": {
|
| 38 |
+
"content": "</s>",
|
| 39 |
+
"lstrip": false,
|
| 40 |
+
"normalized": false,
|
| 41 |
+
"rstrip": false,
|
| 42 |
+
"single_word": false
|
| 43 |
+
},
|
| 44 |
+
"unk_token": {
|
| 45 |
+
"content": "<unk>",
|
| 46 |
+
"lstrip": false,
|
| 47 |
+
"normalized": false,
|
| 48 |
+
"rstrip": false,
|
| 49 |
+
"single_word": false
|
| 50 |
+
}
|
| 51 |
+
}
|
tokenizer.json
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e802fe5337779428818439760a1e6161ed36ceed72d4ebcbda9c139a2108fc99
|
| 3 |
+
size 17082988
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,55 @@
|
| 1 |
+
{
|
| 2 |
+
"added_tokens_decoder": {
|
| 3 |
+
"0": {
|
| 4 |
+
"content": "<s>",
|
| 5 |
+
"lstrip": false,
|
| 6 |
+
"normalized": false,
|
| 7 |
+
"rstrip": false,
|
| 8 |
+
"single_word": false,
|
| 9 |
+
"special": true
|
| 10 |
+
},
|
| 11 |
+
"1": {
|
| 12 |
+
"content": "<pad>",
|
| 13 |
+
"lstrip": false,
|
| 14 |
+
"normalized": false,
|
| 15 |
+
"rstrip": false,
|
| 16 |
+
"single_word": false,
|
| 17 |
+
"special": true
|
| 18 |
+
},
|
| 19 |
+
"2": {
|
| 20 |
+
"content": "</s>",
|
| 21 |
+
"lstrip": false,
|
| 22 |
+
"normalized": false,
|
| 23 |
+
"rstrip": false,
|
| 24 |
+
"single_word": false,
|
| 25 |
+
"special": true
|
| 26 |
+
},
|
| 27 |
+
"3": {
|
| 28 |
+
"content": "<unk>",
|
| 29 |
+
"lstrip": false,
|
| 30 |
+
"normalized": false,
|
| 31 |
+
"rstrip": false,
|
| 32 |
+
"single_word": false,
|
| 33 |
+
"special": true
|
| 34 |
+
},
|
| 35 |
+
"250001": {
|
| 36 |
+
"content": "<mask>",
|
| 37 |
+
"lstrip": true,
|
| 38 |
+
"normalized": false,
|
| 39 |
+
"rstrip": false,
|
| 40 |
+
"single_word": false,
|
| 41 |
+
"special": true
|
| 42 |
+
}
|
| 43 |
+
},
|
| 44 |
+
"bos_token": "<s>",
|
| 45 |
+
"clean_up_tokenization_spaces": true,
|
| 46 |
+
"cls_token": "<s>",
|
| 47 |
+
"eos_token": "</s>",
|
| 48 |
+
"extra_special_tokens": {},
|
| 49 |
+
"mask_token": "<mask>",
|
| 50 |
+
"model_max_length": 1024,
|
| 51 |
+
"pad_token": "<pad>",
|
| 52 |
+
"sep_token": "</s>",
|
| 53 |
+
"tokenizer_class": "XLMRobertaTokenizerFast",
|
| 54 |
+
"unk_token": "<unk>"
|
| 55 |
+
}
|
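Reviewer note: given the special tokens configured above (`<s>` is id 0 and `</s>` is id 2 in `added_tokens_decoder`), the query/passage layout that `_merge_inputs` in `modeling_xlm_roberta.py` appears to build can be sketched with plain lists. This is a hedged illustration only. `merge_pair` is a hypothetical helper, the content ids (10, 11, 20, ...) are made up, and the sketch assumes the tokenizer wraps the query alone as `<s> query </s>`.

```python
from typing import Dict, List

BOS_ID = 0  # "<s>", also used as the CLS token
SEP_ID = 2  # "</s>", also used as the EOS token


def merge_pair(query_ids: List[int], passage_ids: List[int]) -> Dict[str, List[int]]:
    """Build <s> query </s> </s> passage </s> (hypothetical helper)."""
    # Assumed tokenizer output for the query on its own.
    input_ids = [BOS_ID, *query_ids, SEP_ID]
    # The two separators added around the passage when merging.
    input_ids += [SEP_ID, *passage_ids, SEP_ID]
    # No padding at this stage, so every position is attended to.
    return {"input_ids": input_ids, "attention_mask": [1] * len(input_ids)}


pair = merge_pair([10, 11], [20, 21, 22])
print(pair["input_ids"])
```

The two extra separators are why the passage budget in `reranker_tokenize_preproc` is computed as `max_length - len(query_inputs['input_ids']) - 2`.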
xlm_padding.py
ADDED
|
@@ -0,0 +1,218 @@
|
| 1 |
+
# This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/block.py
|
| 2 |
+
# Commit id: c94cd09744d20f0ac587a351ff6ff2e8ad11ae1b
|
| 3 |
+
|
| 4 |
+
# Previously adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from einops import rearrange, repeat
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class IndexFirstAxis(torch.autograd.Function):
|
| 12 |
+
@staticmethod
|
| 13 |
+
def forward(ctx, input, indices):
|
| 14 |
+
ctx.save_for_backward(indices)
|
| 15 |
+
assert input.ndim >= 2
|
| 16 |
+
ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
|
| 17 |
+
second_dim = other_shape.numel()
|
| 18 |
+
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
|
| 19 |
+
# return input[indices]
|
| 20 |
+
return torch.gather(
|
| 21 |
+
rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)
|
| 22 |
+
).reshape(-1, *other_shape)
|
| 23 |
+
|
| 24 |
+
@staticmethod
|
| 25 |
+
def backward(ctx, grad_output):
|
| 26 |
+
(indices,) = ctx.saved_tensors
|
| 27 |
+
assert grad_output.ndim >= 2
|
| 28 |
+
other_shape = grad_output.shape[1:]
|
| 29 |
+
grad_output = rearrange(grad_output, "b ... -> b (...)")
|
| 30 |
+
grad_input = torch.zeros(
|
| 31 |
+
[ctx.first_axis_dim, grad_output.shape[1]],
|
| 32 |
+
device=grad_output.device,
|
| 33 |
+
dtype=grad_output.dtype,
|
| 34 |
+
)
|
| 35 |
+
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
|
| 36 |
+
# grad_input[indices] = grad_output
|
| 37 |
+
grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
|
| 38 |
+
return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
index_first_axis = IndexFirstAxis.apply
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class IndexPutFirstAxis(torch.autograd.Function):
|
| 45 |
+
@staticmethod
|
| 46 |
+
def forward(ctx, values, indices, first_axis_dim):
|
| 47 |
+
ctx.save_for_backward(indices)
|
| 48 |
+
assert indices.ndim == 1
|
| 49 |
+
assert values.ndim >= 2
|
| 50 |
+
output = torch.zeros(
|
| 51 |
+
first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype
|
| 52 |
+
)
|
| 53 |
+
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
|
| 54 |
+
output[indices] = values
|
| 55 |
+
# output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
|
| 56 |
+
return output
|
| 57 |
+
|
| 58 |
+
@staticmethod
|
| 59 |
+
def backward(ctx, grad_output):
|
| 60 |
+
(indices,) = ctx.saved_tensors
|
| 61 |
+
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
|
| 62 |
+
grad_values = grad_output[indices]
|
| 63 |
+
# grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
|
| 64 |
+
return grad_values, None, None
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
index_put_first_axis = IndexPutFirstAxis.apply
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class IndexFirstAxisResidual(torch.autograd.Function):
|
| 71 |
+
@staticmethod
|
| 72 |
+
def forward(ctx, input, indices):
|
| 73 |
+
ctx.save_for_backward(indices)
|
| 74 |
+
assert input.ndim >= 2
|
| 75 |
+
ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
|
| 76 |
+
second_dim = other_shape.numel()
|
| 77 |
+
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
|
| 78 |
+
output = input[indices]
|
| 79 |
+
# We don't want to reshape input (b ... -> b (...)) since it could change the channel_last
|
| 80 |
+
# memory format to channel_first. In other words, input might not be contiguous.
|
| 81 |
+
# If we don't detach, Pytorch complains about output being a view and is being modified inplace
|
| 82 |
+
return output, input.detach()
|
| 83 |
+
|
| 84 |
+
@staticmethod
|
| 85 |
+
def backward(ctx, grad_output, grad_residual):
|
| 86 |
+
(indices,) = ctx.saved_tensors
|
| 87 |
+
assert grad_output.ndim >= 2
|
| 88 |
+
other_shape = grad_output.shape[1:]
|
| 89 |
+
assert grad_residual.shape[1:] == other_shape
|
| 90 |
+
grad_input = grad_residual
|
| 91 |
+
# grad_input[indices] += grad_output
|
| 92 |
+
indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1)))
|
| 93 |
+
indices = indices.expand_as(grad_output)
|
| 94 |
+
grad_input.scatter_add_(0, indices, grad_output)
|
| 95 |
+
return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
index_first_axis_residual = IndexFirstAxisResidual.apply
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def unpad_input(hidden_states, attention_mask):
|
| 102 |
+
"""
|
| 103 |
+
Arguments:
|
| 104 |
+
hidden_states: (batch, seqlen, ...)
|
| 105 |
+
attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
|
| 106 |
+
Return:
|
| 107 |
+
hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
|
| 108 |
+
indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
|
| 109 |
+
cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
|
| 110 |
+
max_seqlen_in_batch: int
|
| 111 |
+
"""
|
| 112 |
+
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
| 113 |
+
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
| 114 |
+
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
| 115 |
+
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
| 116 |
+
# TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
|
| 117 |
+
# bool mask, then call nonzero to get the indices, then index with those. The indices are @dim
|
| 118 |
+
# times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
|
| 119 |
+
# index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
|
| 120 |
+
# so we write custom forward and backward to make it a bit faster.
|
| 121 |
+
return (
|
| 122 |
+
index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
|
| 123 |
+
indices,
|
| 124 |
+
cu_seqlens,
|
| 125 |
+
max_seqlen_in_batch,
|
| 126 |
+
)
|
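As a rough illustration of the bookkeeping in `unpad_input`, the sketch below recomputes `indices`, `cu_seqlens`, and `max_seqlen_in_batch` with plain Python lists instead of tensors. The helper name `unpad_metadata` and the sample mask are assumptions made for illustration; they are not part of this file.

```python
from itertools import accumulate


def unpad_metadata(attention_mask):
    """Plain-Python stand-in for the metadata computed by unpad_input (hypothetical helper)."""
    seqlen = len(attention_mask[0])
    # Number of valid tokens per batch row.
    seqlens_in_batch = [sum(row) for row in attention_mask]
    # Positions of valid tokens in the flattened (batch * seqlen) layout.
    indices = [
        b * seqlen + s
        for b, row in enumerate(attention_mask)
        for s, m in enumerate(row)
        if m
    ]
    # Cumulative lengths with a leading zero, mirroring F.pad(cumsum(...), (1, 0)).
    cu_seqlens = [0] + list(accumulate(seqlens_in_batch))
    return indices, cu_seqlens, max(seqlens_in_batch)


mask = [
    [1, 1, 1, 0],  # 3 valid tokens, 1 padding token
    [1, 1, 0, 0],  # 2 valid tokens, 2 padding tokens
]
indices, cu_seqlens, max_seqlen_in_batch = unpad_metadata(mask)
print(indices, cu_seqlens, max_seqlen_in_batch)
```

Sequence `i` then occupies rows `cu_seqlens[i]` up to, but not including, `cu_seqlens[i + 1]` of the unpadded tensor.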
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length):
|
| 130 |
+
"""
|
| 131 |
+
Supports concatenating several short samples into one sequence. The attention_mask_in_length is used to mask the short samples from one another, which makes training on variable-length samples more efficient (e.g., supervised fine-tuning of large language models).
|
| 132 |
+
The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286).
|
| 133 |
+
|
| 134 |
+
For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is:
|
| 135 |
+
```
|
| 136 |
+
[
|
| 137 |
+
[2, 3, 0, 0, 0, 0],
|
| 138 |
+
[3, 2, 0, 0, 0, 0],
|
| 139 |
+
[6, 0, 0, 0, 0, 0]
|
| 140 |
+
]
|
| 141 |
+
```
|
| 142 |
+
, which corresponds to the 3D attention mask:
|
| 143 |
+
```
|
| 144 |
+
[
|
| 145 |
+
[
|
| 146 |
+
[1, 0, 0, 0, 0, 0],
|
| 147 |
+
[1, 1, 0, 0, 0, 0],
|
| 148 |
+
[0, 0, 1, 0, 0, 0],
|
| 149 |
+
[0, 0, 1, 1, 0, 0],
|
| 150 |
+
[0, 0, 1, 1, 1, 0],
|
| 151 |
+
[0, 0, 0, 0, 0, 1]
|
| 152 |
+
],
|
| 153 |
+
[
|
| 154 |
+
[1, 0, 0, 0, 0, 0],
|
| 155 |
+
[1, 1, 0, 0, 0, 0],
|
| 156 |
+
[1, 1, 1, 0, 0, 0],
|
| 157 |
+
[0, 0, 0, 1, 0, 0],
|
| 158 |
+
[0, 0, 0, 1, 1, 0],
|
| 159 |
+
[0, 0, 0, 0, 0, 1]
|
| 160 |
+
],
|
| 161 |
+
[
|
| 162 |
+
[1, 0, 0, 0, 0, 0],
|
| 163 |
+
[1, 1, 0, 0, 0, 0],
|
| 164 |
+
[1, 1, 1, 0, 0, 0],
|
| 165 |
+
[1, 1, 1, 1, 0, 0],
|
| 166 |
+
[1, 1, 1, 1, 1, 0],
|
| 167 |
+
[1, 1, 1, 1, 1, 1]
|
| 168 |
+
]
|
| 169 |
+
]
|
| 170 |
+
```.
|
| 171 |
+
|
| 172 |
+
Arguments:
|
| 173 |
+
hidden_states: (batch, seqlen, ...)
|
| 174 |
+
attention_mask_in_length: (batch, seqlen), int; each nonzero entry (e.g., 1, 2, 3) is the length of one concatenated sequence in that batch row, and 0 means no sequence.
|
| 175 |
+
Return:
|
| 176 |
+
hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected by attention_mask_in_length.
|
| 177 |
+
indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
|
| 178 |
+
cu_seqlens: (num_sequences + 1), the cumulative lengths of the concatenated sequences across the whole batch, used to index into hidden_states.
|
| 179 |
+
max_seqlen_in_batch: int
|
| 180 |
+
"""
|
| 181 |
+
length = attention_mask_in_length.sum(dim=-1)
|
| 182 |
+
seqlen = attention_mask_in_length.size(-1)
|
| 183 |
+
attention_mask_2d = torch.arange(seqlen, device=length.device, dtype=length.dtype).expand(len(length),
|
| 184 |
+
seqlen) < length.unsqueeze(
|
| 185 |
+
1)
|
| 186 |
+
real_indices_idx = torch.nonzero(attention_mask_in_length.flatten(), as_tuple=False).flatten()
|
| 187 |
+
seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
|
| 188 |
+
indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
|
| 189 |
+
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
| 190 |
+
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
| 191 |
+
# TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
|
| 192 |
+
# bool mask, then call nonzero to get the indices, then index with those. The indices are @dim
|
| 193 |
+
# times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
|
| 194 |
+
# index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
|
| 195 |
+
# so we write custom forward and backward to make it a bit faster.
|
| 196 |
+
return (
|
| 197 |
+
index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
|
| 198 |
+
indices,
|
| 199 |
+
cu_seqlens,
|
| 200 |
+
max_seqlen_in_batch,
|
| 201 |
+
)
|
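To make the concatenated case concrete, the sketch below replays the docstring's `batch = 3`, `seqlen = 6` example with plain Python lists. The helper name `concat_unpad_metadata` is hypothetical, and the code only mirrors the metadata logic of the function above, not its tensor operations. Note that `cu_seqlens` gets one entry per concatenated sub-sequence plus a leading zero, so its length here is 6 rather than `batch + 1`.

```python
from itertools import accumulate


def concat_unpad_metadata(attention_mask_in_length):
    """Plain-Python stand-in for the concatenated-sequence metadata (hypothetical helper)."""
    seqlen = len(attention_mask_in_length[0])
    indices = []
    seqlens_in_batch = []
    for b, row in enumerate(attention_mask_in_length):
        # Total number of valid tokens in this row (sum of sub-sequence lengths).
        total = min(sum(row), seqlen)
        # Valid tokens sit at the start of the row in the flattened layout.
        indices.extend(b * seqlen + s for s in range(total))
        # Every nonzero entry is the length of one concatenated sub-sequence.
        seqlens_in_batch.extend(n for n in row if n != 0)
    cu_seqlens = [0] + list(accumulate(seqlens_in_batch))
    return indices, cu_seqlens, max(seqlens_in_batch)


attention_mask_in_length = [
    [2, 3, 0, 0, 0, 0],
    [3, 2, 0, 0, 0, 0],
    [6, 0, 0, 0, 0, 0],
]
indices, cu_seqlens, max_seqlen_in_batch = concat_unpad_metadata(attention_mask_in_length)
print(indices, cu_seqlens, max_seqlen_in_batch)
```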
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def pad_input(hidden_states, indices, batch, seqlen):
|
| 205 |
+
"""
|
| 206 |
+
Arguments:
|
| 207 |
+
hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
|
| 208 |
+
indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
|
| 209 |
+
batch: int, batch size for the padded sequence.
|
| 210 |
+
seqlen: int, maximum sequence length for the padded sequence.
|
| 211 |
+
Return:
|
| 212 |
+
hidden_states: (batch, seqlen, ...)
|
| 213 |
+
"""
|
| 214 |
+
dim = hidden_states.shape[-1]
|
| 215 |
+
# output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
|
| 216 |
+
# output[indices] = hidden_states
|
| 217 |
+
output = index_put_first_axis(hidden_states, indices, batch * seqlen)
|
| 218 |
+
return rearrange(output, "(b s) ... -> b s ...", b=batch)
|
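`pad_input` is the inverse of the unpadding step: it scatters the `total_nnz` rows back to their flattened positions and fills everything else with zeros. The plain-Python sketch below shows the same idea on scalar values. The helper name `pad_rows` and the sample data are assumptions for illustration, not part of this file.

```python
def pad_rows(values, indices, batch, seqlen):
    """Plain-Python stand-in for pad_input on one scalar per token (hypothetical helper)."""
    # Start from a zero-filled flattened (batch * seqlen) buffer.
    flat = [0] * (batch * seqlen)
    # Put each unpadded value back at its original flattened position.
    for value, i in zip(values, indices):
        flat[i] = value
    # Reshape "(b s) -> b s".
    return [flat[b * seqlen:(b + 1) * seqlen] for b in range(batch)]


values = [10, 11, 12, 20, 21]  # unpadded tokens of two sequences (lengths 3 and 2)
indices = [0, 1, 2, 4, 5]      # their positions in the flattened padded layout
padded = pad_rows(values, indices, batch=2, seqlen=4)
print(padded)
```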