ALJIACHI committed on
Commit
9544d48
·
verified ·
1 Parent(s): ef36cb3

Upload modeling.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling.py +12 -13
modeling.py CHANGED
@@ -372,19 +372,18 @@ class NewEmbeddings(nn.Module):
372
  embeddings = inputs_embeds
373
 
374
  # Set and unpad position_ids
375
- if position_ids is None:
376
- if seq_length > self.position_ids.size(0):
377
- self.register_buffer(
378
- "position_ids", torch.arange(seq_length, device=embeddings.device), persistent=False
379
- )
380
- if unpad_inputs:
381
- # [1, cumsum_seq_len]
382
- position_ids = torch.cat([self.position_ids[:l] for l in length]).unsqueeze(0)
383
- else:
384
- # [bs, seq_len]
385
- position_ids = self.position_ids[:seq_length].expand(batch_size, -1)
386
- elif unpad_inputs:
387
- position_ids = position_ids[attention_mask_bool].unsqueeze(0) # [1, cumsum_seq_len]
388
 
389
  # Compute rotary embedding
390
  if self.position_embedding_type == 'rope':
 
372
  embeddings = inputs_embeds
373
 
374
  # Set and unpad position_ids
375
+ # Always regenerate internally for robustness — avoids corrupt position_ids
376
+ # from external callers (e.g. sentence-transformers in some environments)
377
+ if seq_length > self.position_ids.size(0):
378
+ self.register_buffer(
379
+ "position_ids", torch.arange(seq_length, device=embeddings.device), persistent=False
380
+ )
381
+ if unpad_inputs:
382
+ # [1, cumsum_seq_len]
383
+ position_ids = torch.cat([self.position_ids[:l] for l in length]).unsqueeze(0)
384
+ else:
385
+ # [bs, seq_len]
386
+ position_ids = self.position_ids[:seq_length].expand(batch_size, -1)
 
387
 
388
  # Compute rotary embedding
389
  if self.position_embedding_type == 'rope':