Upload modeling_workshop_gpt.py
Browse files
- modeling_workshop_gpt.py +8 -9
modeling_workshop_gpt.py
CHANGED
|
@@ -36,19 +36,18 @@ class RotaryPositionalEmbeddings(nn.Module):
|
|
| 36 |
self.dim = dim
|
| 37 |
self.max_seq_len = max_seq_len
|
| 38 |
self.base = base
|
| 39 |
-
|
| 40 |
-
self.register_buffer("theta", theta, persistent=False)
|
| 41 |
-
self._build_cache(max_seq_len)
|
| 42 |
|
| 43 |
-
def _build_cache(self, seq_len):
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
|
|
|
| 47 |
|
| 48 |
def forward(self, x, *, input_pos=None):
|
| 49 |
seq_len = x.shape[-2]
|
| 50 |
-
if
|
| 51 |
-
self._build_cache(seq_len)
|
| 52 |
cache = self.cache[:seq_len] if input_pos is None else self.cache[input_pos]
|
| 53 |
x1, x2 = x.float().unflatten(-1, (-1, 2)).unbind(-1)
|
| 54 |
cos, sin = cache.unbind(-1)
|
|
|
|
| 36 |
self.dim = dim
|
| 37 |
self.max_seq_len = max_seq_len
|
| 38 |
self.base = base
|
| 39 |
+
self.cache = None
|
|
|
|
|
|
|
| 40 |
|
| 41 |
+
def _build_cache(self, seq_len, device):
|
| 42 |
+
theta = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device).float() / self.dim))
|
| 43 |
+
seq = torch.arange(seq_len, device=device)
|
| 44 |
+
freqs = torch.outer(seq, theta)
|
| 45 |
+
self.cache = torch.stack([freqs.cos(), freqs.sin()], dim=-1)
|
| 46 |
|
| 47 |
def forward(self, x, *, input_pos=None):
|
| 48 |
seq_len = x.shape[-2]
|
| 49 |
+
if self.cache is None or self.cache.shape[0] < seq_len or self.cache.device != x.device:
|
| 50 |
+
self._build_cache(max(seq_len, self.max_seq_len), x.device)
|
| 51 |
cache = self.cache[:seq_len] if input_pos is None else self.cache[input_pos]
|
| 52 |
x1, x2 = x.float().unflatten(-1, (-1, 2)).unbind(-1)
|
| 53 |
cos, sin = cache.unbind(-1)
|