HemanthSai7 commited on
Commit
5512fbc
·
verified ·
1 Parent(s): 4204a26

Update modeling_nandi.py

Browse files
Files changed (1) hide show
  1. modeling_nandi.py +22 -27
modeling_nandi.py CHANGED
@@ -23,20 +23,20 @@ from collections.abc import Callable
23
  import torch
24
  import torch.nn as nn
25
 
26
- from transformers.activations import ACT2FN
27
- from transformers.cache_utils import Cache, DynamicCache, DynamicLayer
28
- from transformers.generation import GenerationMixin
29
- from transformers.integrations import use_kernel_forward_from_hub
30
- from transformers.masking_utils import create_causal_mask
31
- from transformers.modeling_layers import GradientCheckpointingLayer
32
- from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
33
- from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
34
- from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
35
- from transformers.processing_utils import Unpack
36
- from transformers.utils import TransformersKwargs, auto_docstring
37
- from transformers.utils.deprecation import deprecate_kwarg
38
- from transformers.utils.generic import can_return_tuple, merge_with_config_defaults
39
- from transformers.utils.output_capturing import capture_outputs
40
  from .configuration_nandi import NandiConfig
41
 
42
 
@@ -189,7 +189,6 @@ class NandiAttention(nn.Module):
189
  position_embeddings: tuple[torch.Tensor, torch.Tensor],
190
  attention_mask: torch.Tensor | None,
191
  past_key_values: Cache | None = None,
192
- cache_position: torch.LongTensor | None = None,
193
  **kwargs: Unpack[TransformersKwargs],
194
  ) -> tuple[torch.Tensor, torch.Tensor]:
195
  input_shape = hidden_states.shape[:-1]
@@ -203,8 +202,7 @@ class NandiAttention(nn.Module):
203
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
204
 
205
  if past_key_values is not None:
206
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
207
- key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
208
 
209
  attention_interface: Callable = eager_attention_forward
210
  if self.config._attn_implementation != "eager":
@@ -255,7 +253,6 @@ class NandiDecoderLayer(GradientCheckpointingLayer):
255
  position_ids: torch.LongTensor | None = None,
256
  past_key_values: Cache | None = None,
257
  use_cache: bool | None = False,
258
- cache_position: torch.LongTensor | None = None,
259
  position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
260
  **kwargs: Unpack[TransformersKwargs],
261
  ) -> torch.Tensor:
@@ -268,7 +265,6 @@ class NandiDecoderLayer(GradientCheckpointingLayer):
268
  position_ids=position_ids,
269
  past_key_values=past_key_values,
270
  use_cache=use_cache,
271
- cache_position=cache_position,
272
  position_embeddings=position_embeddings,
273
  **kwargs,
274
  )
@@ -354,7 +350,6 @@ class NandiModel(NandiPreTrainedModel):
354
  position_ids: torch.LongTensor | None = None,
355
  past_key_values: Cache | None = None,
356
  inputs_embeds: torch.FloatTensor | None = None,
357
- cache_position: torch.LongTensor | None = None,
358
  use_cache: bool | None = None,
359
  **kwargs: Unpack[TransformersKwargs],
360
  ) -> BaseModelOutputWithPast:
@@ -370,20 +365,19 @@ class NandiModel(NandiPreTrainedModel):
370
  repeats = self.config.layer_sharing_repeats if self.config.layer_sharing else 1
371
 
372
  if use_cache and past_key_values is None:
 
 
373
  past_key_values = DynamicCache()
374
 
375
- if cache_position is None:
376
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
377
- cache_position = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
378
-
379
  if position_ids is None:
380
- position_ids = cache_position.unsqueeze(0)
 
 
381
 
382
  causal_mask = create_causal_mask(
383
  config=self.config,
384
  inputs_embeds=inputs_embeds,
385
  attention_mask=attention_mask,
386
- cache_position=cache_position,
387
  past_key_values=past_key_values,
388
  position_ids=position_ids,
389
  )
@@ -393,6 +387,8 @@ class NandiModel(NandiPreTrainedModel):
393
 
394
  for decoder_layer in self.layers[: self.config.num_hidden_layers]:
395
  for repeat_idx in range(repeats):
 
 
396
  repeat_cache = (
397
  _VirtualLayerCache(past_key_values, repeat_idx * self.config.num_hidden_layers)
398
  if (past_key_values is not None and repeat_idx > 0)
@@ -405,7 +401,6 @@ class NandiModel(NandiPreTrainedModel):
405
  position_ids=position_ids,
406
  past_key_values=repeat_cache,
407
  use_cache=use_cache,
408
- cache_position=cache_position,
409
  **kwargs,
410
  )
411
 
 
23
  import torch
24
  import torch.nn as nn
25
 
26
+ from ...activations import ACT2FN
27
+ from ...cache_utils import Cache, DynamicCache, DynamicLayer
28
+ from ...generation import GenerationMixin
29
+ from ...integrations import use_kernel_forward_from_hub
30
+ from ...masking_utils import create_causal_mask
31
+ from ...modeling_layers import GradientCheckpointingLayer
32
+ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
33
+ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
34
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
35
+ from ...processing_utils import Unpack
36
+ from ...utils import TransformersKwargs, auto_docstring
37
+ from ...utils.deprecation import deprecate_kwarg
38
+ from ...utils.generic import can_return_tuple, merge_with_config_defaults
39
+ from ...utils.output_capturing import capture_outputs
40
  from .configuration_nandi import NandiConfig
41
 
42
 
 
189
  position_embeddings: tuple[torch.Tensor, torch.Tensor],
190
  attention_mask: torch.Tensor | None,
191
  past_key_values: Cache | None = None,
 
192
  **kwargs: Unpack[TransformersKwargs],
193
  ) -> tuple[torch.Tensor, torch.Tensor]:
194
  input_shape = hidden_states.shape[:-1]
 
202
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
203
 
204
  if past_key_values is not None:
205
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
 
206
 
207
  attention_interface: Callable = eager_attention_forward
208
  if self.config._attn_implementation != "eager":
 
253
  position_ids: torch.LongTensor | None = None,
254
  past_key_values: Cache | None = None,
255
  use_cache: bool | None = False,
 
256
  position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
257
  **kwargs: Unpack[TransformersKwargs],
258
  ) -> torch.Tensor:
 
265
  position_ids=position_ids,
266
  past_key_values=past_key_values,
267
  use_cache=use_cache,
 
268
  position_embeddings=position_embeddings,
269
  **kwargs,
270
  )
 
350
  position_ids: torch.LongTensor | None = None,
351
  past_key_values: Cache | None = None,
352
  inputs_embeds: torch.FloatTensor | None = None,
 
353
  use_cache: bool | None = None,
354
  **kwargs: Unpack[TransformersKwargs],
355
  ) -> BaseModelOutputWithPast:
 
365
  repeats = self.config.layer_sharing_repeats if self.config.layer_sharing else 1
366
 
367
  if use_cache and past_key_values is None:
368
+ # Use lazy DynamicCache (no config) so it grows to accommodate
369
+ # num_hidden_layers * repeats virtual slots for layer-sharing.
370
  past_key_values = DynamicCache()
371
 
 
 
 
 
372
  if position_ids is None:
373
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
374
+ position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
375
+ position_ids = position_ids.unsqueeze(0)
376
 
377
  causal_mask = create_causal_mask(
378
  config=self.config,
379
  inputs_embeds=inputs_embeds,
380
  attention_mask=attention_mask,
 
381
  past_key_values=past_key_values,
382
  position_ids=position_ids,
383
  )
 
387
 
388
  for decoder_layer in self.layers[: self.config.num_hidden_layers]:
389
  for repeat_idx in range(repeats):
390
+ # Each repeat gets its own virtual cache slots offset by num_hidden_layers,
391
+ # so repeat 0 uses slots 0..N-1 and repeat 1 uses slots N..2N-1, etc.
392
  repeat_cache = (
393
  _VirtualLayerCache(past_key_values, repeat_idx * self.config.num_hidden_layers)
394
  if (past_key_values is not None and repeat_idx > 0)
 
401
  position_ids=position_ids,
402
  past_key_values=repeat_cache,
403
  use_cache=use_cache,
 
404
  **kwargs,
405
  )
406