fix
Browse files- modeling_jais.py +0 -3
modeling_jais.py
CHANGED
|
@@ -813,17 +813,14 @@ class JAISModel(JAISPreTrainedModel):
|
|
| 813 |
if position_ids is not None:
|
| 814 |
position_ids = position_ids.view(-1, input_shape[-1])
|
| 815 |
|
| 816 |
-
import pdb;pdb.set_trace()
|
| 817 |
if past_key_values is None:
|
| 818 |
past_length = 0
|
| 819 |
past_key_values = tuple([None] * len(self.h))
|
| 820 |
else:
|
| 821 |
if isinstance(past_key_values, tuple):
|
| 822 |
-
import pdb;pdb.set_trace()
|
| 823 |
past_length = past_key_values[0][0].size(-2)
|
| 824 |
else:
|
| 825 |
past_length = past_key_values.get_seq_length()
|
| 826 |
-
#past_length = past_key_values[0][0].size(-2)
|
| 827 |
if position_ids is None:
|
| 828 |
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
| 829 |
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
|
|
|
|
| 813 |
if position_ids is not None:
|
| 814 |
position_ids = position_ids.view(-1, input_shape[-1])
|
| 815 |
|
|
|
|
| 816 |
if past_key_values is None:
|
| 817 |
past_length = 0
|
| 818 |
past_key_values = tuple([None] * len(self.h))
|
| 819 |
else:
|
| 820 |
if isinstance(past_key_values, tuple):
|
|
|
|
| 821 |
past_length = past_key_values[0][0].size(-2)
|
| 822 |
else:
|
| 823 |
past_length = past_key_values.get_seq_length()
|
|
|
|
| 824 |
if position_ids is None:
|
| 825 |
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
| 826 |
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
|