Update splade.py
Browse files
splade.py
CHANGED
|
@@ -42,7 +42,7 @@ class Splade(PreTrainedModel):
|
|
| 42 |
similarity = similarity
|
| 43 |
encode = encode
|
| 44 |
|
| 45 |
-
def __init__(self, config, weights_path=None):
|
| 46 |
super().__init__(config)
|
| 47 |
self.name = "splade"
|
| 48 |
|
|
@@ -68,6 +68,7 @@ class Splade(PreTrainedModel):
|
|
| 68 |
attn_implementation=config.attn_implementation,
|
| 69 |
bidirectional=getattr(config, "bidirectional", False),
|
| 70 |
base_cfg=base_cfg,
|
|
|
|
| 71 |
)
|
| 72 |
|
| 73 |
def save_pretrained(self, save_directory, *args, **kwargs):
|
|
@@ -83,7 +84,7 @@ class Splade(PreTrainedModel):
|
|
| 83 |
token=token,
|
| 84 |
)
|
| 85 |
|
| 86 |
-
model = cls(config, weights_path=model_name_or_path)
|
| 87 |
|
| 88 |
model.reverse_voc = {v: k for k, v in model.tokenizer.vocab.items()}
|
| 89 |
return model
|
|
|
|
| 42 |
similarity = similarity
|
| 43 |
encode = encode
|
| 44 |
|
| 45 |
+
def __init__(self, config, weights_path=None, token=None):
|
| 46 |
super().__init__(config)
|
| 47 |
self.name = "splade"
|
| 48 |
|
|
|
|
| 68 |
attn_implementation=config.attn_implementation,
|
| 69 |
bidirectional=getattr(config, "bidirectional", False),
|
| 70 |
base_cfg=base_cfg,
|
| 71 |
+
token=token
|
| 72 |
)
|
| 73 |
|
| 74 |
def save_pretrained(self, save_directory, *args, **kwargs):
|
|
|
|
| 84 |
token=token,
|
| 85 |
)
|
| 86 |
|
| 87 |
+
model = cls(config, weights_path=model_name_or_path, token=token)
|
| 88 |
|
| 89 |
model.reverse_voc = {v: k for k, v in model.tokenizer.vocab.items()}
|
| 90 |
return model
|