y3i12 committed on
Commit
c855cdb
·
1 Parent(s): a2df0cc

fixes DynamicCache handling in model conversion

Browse files
Files changed (2) hide show
  1. __init__.py +4 -0
  2. modeling_prisma.py +17 -10
__init__.py CHANGED
@@ -10,6 +10,8 @@ from .model import CircuitTransformer, count_parameters
10
  from .mirrored import MirroredConfig, MirroredTransformer, count_mirrored_parameters
11
  from .data import get_tokenizer, load_data, create_dataloader, TextDataset
12
  from .graft_g2lu import G2LU_GraftedModel, G2LU_MLP, load_g2lu_model
 
 
13
 
14
  __all__ = [
15
  "CircuitConfig",
@@ -25,4 +27,6 @@ __all__ = [
25
  "G2LU_GraftedModel",
26
  "G2LU_MLP",
27
  "load_g2lu_model",
 
 
28
  ]
 
10
  from .mirrored import MirroredConfig, MirroredTransformer, count_mirrored_parameters
11
  from .data import get_tokenizer, load_data, create_dataloader, TextDataset
12
  from .graft_g2lu import G2LU_GraftedModel, G2LU_MLP, load_g2lu_model
13
+ from .configuration_prisma import PrismaConfig
14
+ from .modeling_prisma import PrismaForCausalLM
15
 
16
  __all__ = [
17
  "CircuitConfig",
 
27
  "G2LU_GraftedModel",
28
  "G2LU_MLP",
29
  "load_g2lu_model",
30
+ "PrismaConfig",
31
+ "PrismaForCausalLM",
32
  ]
modeling_prisma.py CHANGED
@@ -102,15 +102,15 @@ class PrismaForCausalLM(PreTrainedModel):
102
  # Convert HF DynamicCache to our list-of-tuples format
103
  past_kv_list = None
104
  if past_key_values is not None:
105
- if hasattr(past_key_values, 'key_cache'):
106
- # HF DynamicCache
107
- if len(past_key_values) > 0:
108
- past_kv_list = [
109
- (past_key_values.key_cache[i], past_key_values.value_cache[i])
110
- for i in range(len(past_key_values))
111
- ]
112
- elif isinstance(past_key_values, (list, tuple)):
113
- past_kv_list = past_key_values
114
 
115
  # Compute word positions if WoRPE is enabled
116
  word_positions = None
@@ -142,7 +142,7 @@ class PrismaForCausalLM(PreTrainedModel):
142
 
143
  # Convert our list-of-tuples back to DynamicCache
144
  new_cache = None
145
- if output.get("past_kv") is not None:
146
  from transformers.cache_utils import DynamicCache
147
  new_cache = DynamicCache()
148
  for layer_idx, (k, v) in enumerate(output["past_kv"]):
@@ -163,7 +163,14 @@ class PrismaForCausalLM(PreTrainedModel):
163
  def prepare_inputs_for_generation(
164
  self, input_ids, past_key_values=None, **kwargs
165
  ):
 
 
166
  if past_key_values is not None:
 
 
 
 
 
167
  input_ids = input_ids[:, -1:]
168
 
169
  return {
 
102
  # Convert HF DynamicCache to our list-of-tuples format
103
  past_kv_list = None
104
  if past_key_values is not None:
105
+ # Check if cache has actual content (not just pre-allocated empty layers)
106
+ has_content = False
107
+ if isinstance(past_key_values, (list, tuple)):
108
+ has_content = len(past_key_values) > 0
109
+ past_kv_list = past_key_values if has_content else None
110
+ elif hasattr(past_key_values, 'get_seq_length'):
111
+ has_content = past_key_values.get_seq_length() > 0
112
+ if has_content:
113
+ past_kv_list = [past_key_values[i] for i in range(len(past_key_values))]
114
 
115
  # Compute word positions if WoRPE is enabled
116
  word_positions = None
 
142
 
143
  # Convert our list-of-tuples back to DynamicCache
144
  new_cache = None
145
+ if use_cache and output.get("past_kv") is not None:
146
  from transformers.cache_utils import DynamicCache
147
  new_cache = DynamicCache()
148
  for layer_idx, (k, v) in enumerate(output["past_kv"]):
 
163
  def prepare_inputs_for_generation(
164
  self, input_ids, past_key_values=None, **kwargs
165
  ):
166
+ # Only trim to last token if cache has actual KV content
167
+ has_cache = False
168
  if past_key_values is not None:
169
+ if hasattr(past_key_values, 'get_seq_length'):
170
+ has_cache = past_key_values.get_seq_length() > 0
171
+ elif isinstance(past_key_values, (list, tuple)):
172
+ has_cache = len(past_key_values) > 0
173
+ if has_cache:
174
  input_ids = input_ids[:, -1:]
175
 
176
  return {