lhallee committed
Commit bbc0cc5 · verified · 1 Parent(s): 0a19ae9

Upload modeling_e1.py with huggingface_hub

Files changed (1): modeling_e1.py (+5 -5)
modeling_e1.py CHANGED

@@ -377,6 +377,11 @@ class Pooler:
         attention_mask: Optional[torch.Tensor] = None,
         attentions: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
+        if attention_mask is not None:
+            assert attention_mask.sum(dim=-1).min() > 0, (
+                "Pooler received samples with all-zero attention masks. "
+                "This causes NaN from division by zero. Filter empty inputs before pooling."
+            )
         final_emb: List[torch.Tensor] = []
         for pooling_type in self.pooling_types:
             final_emb.append(self.pooling_options[pooling_type](emb=emb, attention_mask=attention_mask, attentions=attentions))
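The hunk above adds a fail-fast guard: mean-style pooling divides by the per-sample token count, so a sample whose attention mask is all zeros produces 0/0 = NaN. A minimal sketch of the failure it prevents (the mean-pooling math here is assumed for illustration, not the exact pooling_options implementation):

import torch

emb = torch.randn(2, 4, 8)                    # (batch, seq_len, hidden)
attention_mask = torch.tensor([[1, 1, 0, 0],  # normal sample
                               [0, 0, 0, 0]]) # empty sample

mask = attention_mask.unsqueeze(-1).float()   # (batch, seq_len, 1)
summed = (emb * mask).sum(dim=1)
counts = mask.sum(dim=1)                      # 0.0 for the empty sample
mean_pooled = summed / counts                 # 0/0 -> NaN for the empty sample
print(torch.isnan(mean_pooled[1]).all())      # tensor(True)

# The new guard raises instead of silently propagating NaN:
assert attention_mask.sum(dim=-1).min() > 0   # AssertionError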
@@ -2398,7 +2403,6 @@ class FAST_E1_ENCODER(E1PreTrainedModel, EmbeddingMixin):
     def set_input_embeddings(self, value: nn.Embedding) -> None:
         self.embed_tokens = value
 
-    @torch.inference_mode()
     def _embed(self, sequences: List[str], return_attention_mask: bool = False, **kwargs) -> torch.Tensor:
         batch = self.prep_tokens.get_batch_kwargs(sequences, device=self._device)
         last_hidden_state = self.forward(**batch, output_hidden_states=False, output_attentions=False).last_hidden_state
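This hunk and the four below all drop the same @torch.inference_mode() decorator from _embed. The commit message does not state a motivation, but a plausible one: tensors created under inference mode are permanently excluded from autograd, so the decorator prevented any gradient-based use of the returned embeddings. A minimal sketch of that restriction (generic modules standing in, not the E1 classes):

import torch

encoder = torch.nn.Linear(8, 8)   # stand-in for a decorated _embed
head = torch.nn.Linear(8, 1)

with torch.inference_mode():
    frozen = encoder(torch.randn(2, 8))  # an "inference tensor"

try:
    head(frozen).sum().backward()
except RuntimeError as err:
    # PyTorch rejects inference tensors in autograd, e.g.
    # "Inference tensors cannot be saved for backward ..."
    print(err)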
@@ -2602,7 +2606,6 @@ class E1Model(E1PreTrainedModel, EmbeddingMixin):
     def set_input_embeddings(self, value: nn.Embedding) -> None:
         self.model.set_input_embeddings(value)
 
-    @torch.inference_mode()
     def _embed(self, sequences: List[str], return_attention_mask: bool = False, **kwargs) -> torch.Tensor:
         return self.model._embed(sequences, return_attention_mask=return_attention_mask, **kwargs)
 
@@ -2656,7 +2659,6 @@ class E1ForMaskedLM(E1PreTrainedModel, EmbeddingMixin):
     def device_mesh(self) -> torch.distributed.device_mesh.DeviceMesh:
         return self.model.device_mesh
 
-    @torch.inference_mode()
     def _embed(self, sequences: List[str], return_attention_mask: bool = False, **kwargs) -> torch.Tensor:
         batch = self.prep_tokens.get_batch_kwargs(sequences, device=self._device)
         last_hidden_state = self.model(**batch, output_hidden_states=False, output_attentions=False).last_hidden_state
@@ -2778,7 +2780,6 @@ class E1ForSequenceClassification(E1PreTrainedModel, EmbeddingMixin):
     def device_mesh(self) -> torch.distributed.device_mesh.DeviceMesh:
         return self.model.device_mesh
 
-    @torch.inference_mode()
     def _embed(self, sequences: List[str], return_attention_mask: bool = False, **kwargs) -> torch.Tensor:
         batch = self.prep_tokens.get_batch_kwargs(sequences, device=self._device)
         last_hidden_state = self.model(**batch, output_hidden_states=False, output_attentions=False).last_hidden_state
@@ -2875,7 +2876,6 @@ class E1ForTokenClassification(E1PreTrainedModel, EmbeddingMixin):
     def device_mesh(self) -> torch.distributed.device_mesh.DeviceMesh:
         return self.model.device_mesh
 
-    @torch.inference_mode()
     def _embed(self, sequences: List[str], return_attention_mask: bool = False, **kwargs) -> torch.Tensor:
         batch = self.prep_tokens.get_batch_kwargs(sequences, device=self._device)
         last_hidden_state = self.model(**batch, output_hidden_states=False, output_attentions=False).last_hidden_state
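With the decorator gone, callers that only need inference can recover the old no-autograd behavior by opting in at the call site. A hedged usage sketch, where model stands for a loaded E1 model instance and sequences for a list of input strings (both hypothetical names, not defined in this file):

import torch

with torch.inference_mode():              # opt back in at the call site
    embeddings = model._embed(sequences)  # model, sequences: assumed to exist

# Without the wrapper, the same call can now feed gradient-based pipelines.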
 