Commit 41d8e19
Parent(s): e19e1b1
Upload 2 files
performer_pytorch/autoregressive_wrapper.py
CHANGED

@@ -4,8 +4,6 @@ from torch import nn
 import torch.nn.functional as F
 from torch.nn.utils.rnn import pad_sequence
 
-import pdb
-
 
 def exists(val):
     return val is not None

@@ -108,6 +106,4 @@ class AutoregressiveWrapper(nn.Module):
 
         loss = F.cross_entropy(out.transpose(1, 2), xo, ignore_index = self.ignore_index)
 
-        #pdb.set_trace()
-
         return loss
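For context on the second hunk: the loss call that remains after removing the breakpoint transposes the logits before handing them to F.cross_entropy, because PyTorch expects the class dimension in position 1. Below is a minimal, self-contained sketch of that shape convention, using made-up sizes rather than anything taken from this repo:

import torch
import torch.nn.functional as F

# Illustrative sizes only; in the wrapper these come from the wrapped Performer model.
batch, seq_len, num_tokens, ignore_index = 2, 8, 100, 0

out = torch.randn(batch, seq_len, num_tokens)         # stand-in for the model logits
xo = torch.randint(1, num_tokens, (batch, seq_len))   # stand-in for the shifted target ids

# F.cross_entropy wants (batch, classes, seq), so transpose (batch, seq, classes) -> (batch, classes, seq);
# ignore_index drops masked/padding positions from the loss.
loss = F.cross_entropy(out.transpose(1, 2), xo, ignore_index=ignore_index)
print(loss)  # scalar tensor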
performer_pytorch/performer_pytorch.py
CHANGED

@@ -12,7 +12,6 @@ from contextlib import contextmanager
 from local_attention import LocalAttention
 from performer_pytorch.reversible import ReversibleSequence, SequentialSequence
 
-import pdb
 
 try:
     from apex import amp

@@ -605,7 +604,6 @@ class PerformerLM(nn.Module):
         b, n, device = *x.shape, x.device
         assert n <= self.max_seq_len, f'sequence length {n} must be less than the max sequence length {self.max_seq_len}'
 
-        #pdb.set_trace()
        # token and positional embedding
         x = self.token_emb(x)
         if output_attentions:
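The second hunk in this file keeps the forward-pass prologue of PerformerLM intact: batch size, sequence length, and device are star-unpacked in one line, and a length guard runs before the token embedding. A small sketch of that pattern follows, where max_seq_len is a hypothetical stand-in for the model's self.max_seq_len attribute:

import torch

max_seq_len = 512                        # stand-in for self.max_seq_len on PerformerLM
x = torch.randint(0, 20000, (4, 128))    # (batch, seq_len) token ids, illustrative values

# x.shape unpacks to (4, 128), so b = 4 and n = 128; device is read off the input tensor.
b, n, device = *x.shape, x.device
assert n <= max_seq_len, f'sequence length {n} must be less than the max sequence length {max_seq_len}'
print(b, n, device)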