Fix for new versions
Browse files
custom_generate/generate.py
CHANGED
|
@@ -179,7 +179,9 @@ def _dola_decoding(
     # keep track of which sequences are already finished
     batch_size, cur_length = input_ids.shape[:2]
     unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
-    model_kwargs = model._get_initial_cache_position(input_ids, model_kwargs)
+    # Does not exist anymore in recent versions!
+    if hasattr(model, "_get_initial_cache_position"):
+        model_kwargs = model._get_initial_cache_position(cur_length, input_ids.device, model_kwargs)

     this_peer_finished = False
