cyrilvallez HF Staff commited on
Commit
56ee9b2
·
verified ·
1 Parent(s): 1c7950a

Fix for new versions

Browse files
Files changed (1) hide show
  1. custom_generate/generate.py +3 -1
custom_generate/generate.py CHANGED
@@ -179,7 +179,9 @@ def _dola_decoding(
179
  # keep track of which sequences are already finished
180
  batch_size, cur_length = input_ids.shape[:2]
181
  unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
182
- model_kwargs = model._get_initial_cache_position(cur_length, input_ids.device, model_kwargs)
 
 
183
 
184
  this_peer_finished = False
185
 
 
179
  # keep track of which sequences are already finished
180
  batch_size, cur_length = input_ids.shape[:2]
181
  unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
182
+ # Does not exist anymore in recent versions!
183
+ if hasattr(model, "_get_initial_cache_position"):
184
+ model_kwargs = model._get_initial_cache_position(cur_length, input_ids.device, model_kwargs)
185
 
186
  this_peer_finished = False
187