HemanthSai7 commited on
Commit
8fde181
·
verified ·
1 Parent(s): 094f4e0

Fix NandiModel.forward: add cache_position for compatibility with released transformers 5.3.0

Browse files
Files changed (1) hide show
  1. modeling_nandi.py +8 -7
modeling_nandi.py CHANGED
@@ -354,6 +354,7 @@ class NandiModel(NandiPreTrainedModel):
354
  position_ids: torch.LongTensor | None = None,
355
  past_key_values: Cache | None = None,
356
  inputs_embeds: torch.FloatTensor | None = None,
 
357
  use_cache: bool | None = None,
358
  **kwargs: Unpack[TransformersKwargs],
359
  ) -> BaseModelOutputWithPast:
@@ -369,19 +370,20 @@ class NandiModel(NandiPreTrainedModel):
369
  repeats = self.config.layer_sharing_repeats if self.config.layer_sharing else 1
370
 
371
  if use_cache and past_key_values is None:
372
- # Use lazy DynamicCache (no config) so it grows to accommodate
373
- # num_hidden_layers * repeats virtual slots for layer-sharing.
374
  past_key_values = DynamicCache()
375
 
376
- if position_ids is None:
377
  past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
378
- position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
379
- position_ids = position_ids.unsqueeze(0)
 
 
380
 
381
  causal_mask = create_causal_mask(
382
  config=self.config,
383
  inputs_embeds=inputs_embeds,
384
  attention_mask=attention_mask,
 
385
  past_key_values=past_key_values,
386
  position_ids=position_ids,
387
  )
@@ -391,8 +393,6 @@ class NandiModel(NandiPreTrainedModel):
391
 
392
  for decoder_layer in self.layers[: self.config.num_hidden_layers]:
393
  for repeat_idx in range(repeats):
394
- # Each repeat gets its own virtual cache slots offset by num_hidden_layers,
395
- # so repeat 0 uses slots 0..N-1 and repeat 1 uses slots N..2N-1, etc.
396
  repeat_cache = (
397
  _VirtualLayerCache(past_key_values, repeat_idx * self.config.num_hidden_layers)
398
  if (past_key_values is not None and repeat_idx > 0)
@@ -405,6 +405,7 @@ class NandiModel(NandiPreTrainedModel):
405
  position_ids=position_ids,
406
  past_key_values=repeat_cache,
407
  use_cache=use_cache,
 
408
  **kwargs,
409
  )
410
 
 
354
  position_ids: torch.LongTensor | None = None,
355
  past_key_values: Cache | None = None,
356
  inputs_embeds: torch.FloatTensor | None = None,
357
+ cache_position: torch.LongTensor | None = None,
358
  use_cache: bool | None = None,
359
  **kwargs: Unpack[TransformersKwargs],
360
  ) -> BaseModelOutputWithPast:
 
370
  repeats = self.config.layer_sharing_repeats if self.config.layer_sharing else 1
371
 
372
  if use_cache and past_key_values is None:
 
 
373
  past_key_values = DynamicCache()
374
 
375
+ if cache_position is None:
376
  past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
377
+ cache_position = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
378
+
379
+ if position_ids is None:
380
+ position_ids = cache_position.unsqueeze(0)
381
 
382
  causal_mask = create_causal_mask(
383
  config=self.config,
384
  inputs_embeds=inputs_embeds,
385
  attention_mask=attention_mask,
386
+ cache_position=cache_position,
387
  past_key_values=past_key_values,
388
  position_ids=position_ids,
389
  )
 
393
 
394
  for decoder_layer in self.layers[: self.config.num_hidden_layers]:
395
  for repeat_idx in range(repeats):
 
 
396
  repeat_cache = (
397
  _VirtualLayerCache(past_key_values, repeat_idx * self.config.num_hidden_layers)
398
  if (past_key_values is not None and repeat_idx > 0)
 
405
  position_ids=position_ids,
406
  past_key_values=repeat_cache,
407
  use_cache=use_cache,
408
+ cache_position=cache_position,
409
  **kwargs,
410
  )
411