Fix DynamicCache handling in model conversion

Files changed:
- __init__.py (+4, −0)
- modeling_prisma.py (+17, −10)
__init__.py
CHANGED
|
@@ -10,6 +10,8 @@ from .model import CircuitTransformer, count_parameters
|
|
| 10 |
from .mirrored import MirroredConfig, MirroredTransformer, count_mirrored_parameters
|
| 11 |
from .data import get_tokenizer, load_data, create_dataloader, TextDataset
|
| 12 |
from .graft_g2lu import G2LU_GraftedModel, G2LU_MLP, load_g2lu_model
|
|
|
|
|
|
|
| 13 |
|
| 14 |
__all__ = [
|
| 15 |
"CircuitConfig",
|
|
@@ -25,4 +27,6 @@ __all__ = [
|
|
| 25 |
"G2LU_GraftedModel",
|
| 26 |
"G2LU_MLP",
|
| 27 |
"load_g2lu_model",
|
|
|
|
|
|
|
| 28 |
]
|
|
|
|
| 10 |
from .mirrored import MirroredConfig, MirroredTransformer, count_mirrored_parameters
|
| 11 |
from .data import get_tokenizer, load_data, create_dataloader, TextDataset
|
| 12 |
from .graft_g2lu import G2LU_GraftedModel, G2LU_MLP, load_g2lu_model
|
| 13 |
+
from .configuration_prisma import PrismaConfig
|
| 14 |
+
from .modeling_prisma import PrismaForCausalLM
|
| 15 |
|
| 16 |
__all__ = [
|
| 17 |
"CircuitConfig",
|
|
|
|
| 27 |
"G2LU_GraftedModel",
|
| 28 |
"G2LU_MLP",
|
| 29 |
"load_g2lu_model",
|
| 30 |
+
"PrismaConfig",
|
| 31 |
+
"PrismaForCausalLM",
|
| 32 |
]
|
modeling_prisma.py
CHANGED
|
@@ -102,15 +102,15 @@ class PrismaForCausalLM(PreTrainedModel):
|
|
| 102 |
# Convert HF DynamicCache to our list-of-tuples format
|
| 103 |
past_kv_list = None
|
| 104 |
if past_key_values is not None:
|
| 105–113 |
- [nine removed lines — the original past_key_values / DynamicCache detection logic; content was truncated during page extraction, only a stray "if" survives. See the replacement lines 105–113 in the second rendering of this hunk below.]
|
| 114 |
|
| 115 |
# Compute word positions if WoRPE is enabled
|
| 116 |
word_positions = None
|
|
@@ -142,7 +142,7 @@ class PrismaForCausalLM(PreTrainedModel):
|
|
| 142 |
|
| 143 |
# Convert our list-of-tuples back to DynamicCache
|
| 144 |
new_cache = None
|
| 145 |
-
if output.get("past_kv") is not None:
|
| 146 |
from transformers.cache_utils import DynamicCache
|
| 147 |
new_cache = DynamicCache()
|
| 148 |
for layer_idx, (k, v) in enumerate(output["past_kv"]):
|
|
@@ -163,7 +163,14 @@ class PrismaForCausalLM(PreTrainedModel):
|
|
| 163 |
def prepare_inputs_for_generation(
|
| 164 |
self, input_ids, past_key_values=None, **kwargs
|
| 165 |
):
|
|
|
|
|
|
|
| 166 |
if past_key_values is not None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
input_ids = input_ids[:, -1:]
|
| 168 |
|
| 169 |
return {
|
|
|
|
| 102 |
# Convert HF DynamicCache to our list-of-tuples format
|
| 103 |
past_kv_list = None
|
| 104 |
if past_key_values is not None:
|
| 105 |
+
# Check if cache has actual content (not just pre-allocated empty layers)
|
| 106 |
+
has_content = False
|
| 107 |
+
if isinstance(past_key_values, (list, tuple)):
|
| 108 |
+
has_content = len(past_key_values) > 0
|
| 109 |
+
past_kv_list = past_key_values if has_content else None
|
| 110 |
+
elif hasattr(past_key_values, 'get_seq_length'):
|
| 111 |
+
has_content = past_key_values.get_seq_length() > 0
|
| 112 |
+
if has_content:
|
| 113 |
+
past_kv_list = [past_key_values[i] for i in range(len(past_key_values))]
|
| 114 |
|
| 115 |
# Compute word positions if WoRPE is enabled
|
| 116 |
word_positions = None
|
|
|
|
| 142 |
|
| 143 |
# Convert our list-of-tuples back to DynamicCache
|
| 144 |
new_cache = None
|
| 145 |
+
if use_cache and output.get("past_kv") is not None:
|
| 146 |
from transformers.cache_utils import DynamicCache
|
| 147 |
new_cache = DynamicCache()
|
| 148 |
for layer_idx, (k, v) in enumerate(output["past_kv"]):
|
|
|
|
| 163 |
def prepare_inputs_for_generation(
|
| 164 |
self, input_ids, past_key_values=None, **kwargs
|
| 165 |
):
|
| 166 |
+
# Only trim to last token if cache has actual KV content
|
| 167 |
+
has_cache = False
|
| 168 |
if past_key_values is not None:
|
| 169 |
+
if hasattr(past_key_values, 'get_seq_length'):
|
| 170 |
+
has_cache = past_key_values.get_seq_length() > 0
|
| 171 |
+
elif isinstance(past_key_values, (list, tuple)):
|
| 172 |
+
has_cache = len(past_key_values) > 0
|
| 173 |
+
if has_cache:
|
| 174 |
input_ids = input_ids[:, -1:]
|
| 175 |
|
| 176 |
return {
|