Update bert_layers.py
Browse filesPlease refer to this [issue]([url](https://github.com/huggingface/transformers/issues/31068))
- bert_layers.py +4 -1
bert_layers.py
CHANGED
|
@@ -25,6 +25,8 @@ from .bert_padding import (index_first_axis,
|
|
| 25 |
index_put_first_axis, pad_input,
|
| 26 |
unpad_input, unpad_input_only)
|
| 27 |
|
|
|
|
|
|
|
| 28 |
try:
|
| 29 |
from .flash_attn_triton import flash_attn_qkvpacked_func
|
| 30 |
except ImportError as e:
|
|
@@ -564,7 +566,8 @@ class BertModel(BertPreTrainedModel):
|
|
| 564 |
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
|
| 565 |
```
|
| 566 |
"""
|
| 567 |
-
|
|
|
|
| 568 |
def __init__(self, config, add_pooling_layer=True):
|
| 569 |
super(BertModel, self).__init__(config)
|
| 570 |
self.embeddings = BertEmbeddings(config)
|
|
|
|
| 25 |
index_put_first_axis, pad_input,
|
| 26 |
unpad_input, unpad_input_only)
|
| 27 |
|
| 28 |
+
from .configuration_bert import BertConfig
|
| 29 |
+
|
| 30 |
try:
|
| 31 |
from .flash_attn_triton import flash_attn_qkvpacked_func
|
| 32 |
except ImportError as e:
|
|
|
|
| 566 |
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
|
| 567 |
```
|
| 568 |
"""
|
| 569 |
+
config_class = BertConfig
|
| 570 |
+
|
| 571 |
def __init__(self, config, add_pooling_layer=True):
|
| 572 |
super(BertModel, self).__init__(config)
|
| 573 |
self.embeddings = BertEmbeddings(config)
|