JustinAngel committed on
Commit
e411988
·
verified ·
1 Parent(s): f85babe

Upload modeling_workshop_gpt.py

Browse files
Files changed (1) hide show
  1. modeling_workshop_gpt.py +8 -9
modeling_workshop_gpt.py CHANGED
@@ -36,19 +36,18 @@ class RotaryPositionalEmbeddings(nn.Module):
36
  self.dim = dim
37
  self.max_seq_len = max_seq_len
38
  self.base = base
39
- theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
40
- self.register_buffer("theta", theta, persistent=False)
41
- self._build_cache(max_seq_len)
42
 
43
def _build_cache(self, seq_len):
    """(Re)compute the rotary cos/sin lookup table for positions [0, seq_len).

    Multiplies each position index by the per-channel inverse frequencies
    held in the `theta` buffer, then stores the result as a non-persistent
    buffer `cache` of shape (seq_len, len(theta), 2), where the last axis
    holds (cos, sin).
    """
    # Build on the same device the frequency buffer currently lives on.
    positions = torch.arange(seq_len, device=self.theta.device)
    # Outer product via broadcasting: angles[p, i] = p * theta[i].
    angles = positions[:, None] * self.theta[None, :]
    self.register_buffer(
        "cache",
        torch.stack((angles.cos(), angles.sin()), dim=-1),
        persistent=False,
    )
 
47
 
48
  def forward(self, x, *, input_pos=None):
49
  seq_len = x.shape[-2]
50
- if seq_len > self.cache.shape[0]:
51
- self._build_cache(seq_len)
52
  cache = self.cache[:seq_len] if input_pos is None else self.cache[input_pos]
53
  x1, x2 = x.float().unflatten(-1, (-1, 2)).unbind(-1)
54
  cos, sin = cache.unbind(-1)
 
36
  self.dim = dim
37
  self.max_seq_len = max_seq_len
38
  self.base = base
39
+ self.cache = None
 
 
40
 
41
def _build_cache(self, seq_len, device):
    """Precompute the rotary cos/sin table on `device` for positions [0, seq_len).

    Re-derives the per-channel inverse frequencies from `self.base` and
    `self.dim` on every call, so the table always lands on the requested
    device. Stores `self.cache` with shape (seq_len, dim // 2, 2), where
    the last axis holds (cos, sin).
    """
    exponent = torch.arange(0, self.dim, 2, device=device).float() / self.dim
    inv_freq = 1.0 / (self.base ** exponent)
    positions = torch.arange(seq_len, device=device)
    # Outer product via broadcasting: angles[p, i] = p * inv_freq[i].
    angles = positions.unsqueeze(-1) * inv_freq
    # Plain attribute rather than a registered buffer: forward() rebuilds
    # it whenever the device or required length changes, so it never needs
    # to survive state_dict round-trips.
    self.cache = torch.stack((angles.cos(), angles.sin()), dim=-1)
46
 
47
  def forward(self, x, *, input_pos=None):
48
  seq_len = x.shape[-2]
49
+ if self.cache is None or self.cache.shape[0] < seq_len or self.cache.device != x.device:
50
+ self._build_cache(max(seq_len, self.max_seq_len), x.device)
51
  cache = self.cache[:seq_len] if input_pos is None else self.cache[input_pos]
52
  x1, x2 = x.float().unflatten(-1, (-1, 2)).unbind(-1)
53
  cos, sin = cache.unbind(-1)