qubitpage committed on
Commit
3ba2208
·
verified ·
1 Parent(s): 90885ba

Upload hf_model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. hf_model.py +29 -11
hf_model.py CHANGED
@@ -357,18 +357,36 @@ class SentinelBrainForCausalLM(PreTrainedModel, GenerationMixin):
357
  B, T = input_ids.shape
358
  x = self.tok_emb(input_ids)
359
 
360
- # Determine if we have valid past KV caches
 
361
  has_past = False
362
  past_len = 0
363
- if past_key_values is not None and len(past_key_values) > 0:
364
- first = past_key_values[0]
365
- if first is not None:
366
- if isinstance(first, (tuple, list)) and len(first) > 0 and first[0] is not None:
367
- has_past = True
368
- past_len = first[0].shape[2]
369
- elif hasattr(first, 'shape'):
370
- has_past = True
371
- past_len = first.shape[2]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
372
 
373
  rope_cos, rope_sin = self.rope(past_len + T)
374
  rope_cos = rope_cos[:, :, past_len:past_len + T].to(x.device)
@@ -379,7 +397,7 @@ class SentinelBrainForCausalLM(PreTrainedModel, GenerationMixin):
379
  total_z = 0.0
380
 
381
  for i, layer in enumerate(self.layers):
382
- kv_cache = past_key_values[i] if has_past else None
383
  x, new_kv, aux, z = layer(x, rope_cos, rope_sin, kv_cache=kv_cache)
384
  new_kv_caches.append(new_kv)
385
  total_aux += aux
 
357
  B, T = input_ids.shape
358
  x = self.tok_emb(input_ids)
359
 
360
+ # Determine if we have valid past KV caches.
361
+ # Support: list-of-tuples (legacy), tuple-of-tuples, and DynamicCache (new transformers).
362
  has_past = False
363
  past_len = 0
364
+ _legacy_past = None # normalized to list-of-tuples form
365
+
366
+ if past_key_values is not None:
367
+ # New API: DynamicCache or similar Cache object
368
+ if hasattr(past_key_values, "to_legacy_cache"):
369
+ try:
370
+ legacy = past_key_values.to_legacy_cache()
371
+ if legacy is not None and len(legacy) > 0:
372
+ _legacy_past = list(legacy)
373
+ first = _legacy_past[0]
374
+ if first is not None and len(first) > 0 and first[0] is not None:
375
+ has_past = True
376
+ past_len = first[0].shape[2]
377
+ except Exception:
378
+ pass
379
+ # Legacy API: list/tuple of (k, v) tuples
380
+ elif isinstance(past_key_values, (list, tuple)) and len(past_key_values) > 0:
381
+ _legacy_past = list(past_key_values)
382
+ first = _legacy_past[0]
383
+ if first is not None:
384
+ if isinstance(first, (tuple, list)) and len(first) > 0 and first[0] is not None:
385
+ has_past = True
386
+ past_len = first[0].shape[2]
387
+ elif hasattr(first, "shape"):
388
+ has_past = True
389
+ past_len = first.shape[2]
390
 
391
  rope_cos, rope_sin = self.rope(past_len + T)
392
  rope_cos = rope_cos[:, :, past_len:past_len + T].to(x.device)
 
397
  total_z = 0.0
398
 
399
  for i, layer in enumerate(self.layers):
400
+ kv_cache = _legacy_past[i] if (has_past and _legacy_past is not None and i < len(_legacy_past)) else None
401
  x, new_kv, aux, z = layer(x, rope_cos, rope_sin, kv_cache=kv_cache)
402
  new_kv_caches.append(new_kv)
403
  total_aux += aux