Upload modeling_e1.py with huggingface_hub
modeling_e1.py  CHANGED  (+5, -5)
@@ -377,6 +377,11 @@ class Pooler:
         attention_mask: Optional[torch.Tensor] = None,
         attentions: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
+        if attention_mask is not None:
+            assert attention_mask.sum(dim=-1).min() > 0, (
+                "Pooler received samples with all-zero attention masks. "
+                "This causes NaN from division by zero. Filter empty inputs before pooling."
+            )
         final_emb: List[torch.Tensor] = []
         for pooling_type in self.pooling_types:
             final_emb.append(self.pooling_options[pooling_type](emb=emb, attention_mask=attention_mask, attentions=attentions))
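For context on the new guard: mean-style pooling divides the masked sum of embeddings by the number of unmasked tokens, so a sample whose attention mask is all zeros yields 0/0 = NaN. A minimal sketch of the failure mode, assuming a typical masked-mean pooler (mean_pool below is illustrative, not the actual implementation in modeling_e1.py):

    import torch

    def mean_pool(emb: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # Typical masked mean: zero out padded positions, then divide by token count.
        mask = attention_mask.unsqueeze(-1).to(emb.dtype)  # (B, T, 1)
        summed = (emb * mask).sum(dim=1)                   # (B, H)
        counts = mask.sum(dim=1)                           # (B, 1); zero for empty samples
        return summed / counts                             # 0/0 -> NaN for empty samples

    emb = torch.randn(2, 4, 8)
    mask = torch.tensor([[1, 1, 0, 0],
                         [0, 0, 0, 0]])                    # second sample: all-zero mask
    print(torch.isnan(mean_pool(emb, mask))[1].all())      # tensor(True)

The added assert fails fast on such inputs instead of silently propagating NaN embeddings downstream.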
@@ -2398,7 +2403,6 @@ class FAST_E1_ENCODER(E1PreTrainedModel, EmbeddingMixin):
     def set_input_embeddings(self, value: nn.Embedding) -> None:
         self.embed_tokens = value

-    @torch.inference_mode()
     def _embed(self, sequences: List[str], return_attention_mask: bool = False, **kwargs) -> torch.Tensor:
         batch = self.prep_tokens.get_batch_kwargs(sequences, device=self._device)
         last_hidden_state = self.forward(**batch, output_hidden_states=False, output_attentions=False).last_hidden_state
@@ -2602,7 +2606,6 @@ class E1Model(E1PreTrainedModel, EmbeddingMixin):
     def set_input_embeddings(self, value: nn.Embedding) -> None:
         self.model.set_input_embeddings(value)

-    @torch.inference_mode()
     def _embed(self, sequences: List[str], return_attention_mask: bool = False, **kwargs) -> torch.Tensor:
         return self.model._embed(sequences, return_attention_mask=return_attention_mask, **kwargs)

@@ -2656,7 +2659,6 @@ class E1ForMaskedLM(E1PreTrainedModel, EmbeddingMixin):
     def device_mesh(self) -> torch.distributed.device_mesh.DeviceMesh:
         return self.model.device_mesh

-    @torch.inference_mode()
     def _embed(self, sequences: List[str], return_attention_mask: bool = False, **kwargs) -> torch.Tensor:
         batch = self.prep_tokens.get_batch_kwargs(sequences, device=self._device)
         last_hidden_state = self.model(**batch, output_hidden_states=False, output_attentions=False).last_hidden_state
@@ -2778,7 +2780,6 @@ class E1ForSequenceClassification(E1PreTrainedModel, EmbeddingMixin):
     def device_mesh(self) -> torch.distributed.device_mesh.DeviceMesh:
         return self.model.device_mesh

-    @torch.inference_mode()
     def _embed(self, sequences: List[str], return_attention_mask: bool = False, **kwargs) -> torch.Tensor:
         batch = self.prep_tokens.get_batch_kwargs(sequences, device=self._device)
         last_hidden_state = self.model(**batch, output_hidden_states=False, output_attentions=False).last_hidden_state
@@ -2875,7 +2876,6 @@ class E1ForTokenClassification(E1PreTrainedModel, EmbeddingMixin):
     def device_mesh(self) -> torch.distributed.device_mesh.DeviceMesh:
         return self.model.device_mesh

-    @torch.inference_mode()
     def _embed(self, sequences: List[str], return_attention_mask: bool = False, **kwargs) -> torch.Tensor:
         batch = self.prep_tokens.get_batch_kwargs(sequences, device=self._device)
         last_hidden_state = self.model(**batch, output_hidden_states=False, output_attentions=False).last_hidden_state
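The remaining five hunks all drop the @torch.inference_mode() decorator from _embed. One consequence of that decorator, which plausibly motivates the removal: tensors created under torch.inference_mode() are marked as inference tensors and can never participate in autograd afterwards, so the returned embeddings could not feed a differentiable loss. A small sketch of the restriction:

    import torch

    lin = torch.nn.Linear(4, 4)
    x = torch.randn(1, 4)

    with torch.inference_mode():
        y = lin(x)                      # y is an "inference tensor"

    w = torch.ones(1, 4, requires_grad=True)
    try:
        (y * w).sum().backward()        # mixing an inference tensor into autograd
    except RuntimeError as err:
        print(err)                      # "Inference tensors cannot be saved for backward. ..."

Without the decorator, callers that only want fast inference can still wrap the call themselves in torch.no_grad() or torch.inference_mode().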