vikhyatk err805 committed on
Commit
5112966
·
1 Parent(s): a21cc17

Fix runtime buffers after load (#37)

Browse files

- Fix runtime buffers after load (8505639715a6ac8faed98a9e727b1259a51a4c31)
- Fix runtime buffers after load (4d8ba2a73affbbc2a185c96e128c277f3a21447b)


Co-authored-by: Ethan Reid <err805@users.noreply.huggingface.co>

Files changed (2) hide show
  1. hf_moondream.py +7 -0
  2. moondream.py +21 -0
hf_moondream.py CHANGED
@@ -45,6 +45,13 @@ class HfMoondream(PreTrainedModel):
45
  self._is_kv_cache_setup = False
46
  self.post_init()
47
 
 
 
 
 
 
 
 
48
  def _setup_caches(self):
49
  if not self._is_kv_cache_setup:
50
  self.model._setup_caches()
 
45
  self._is_kv_cache_setup = False
46
  self.post_init()
47
 
48
+ @classmethod
49
+ def from_pretrained(cls, *args, **kwargs):
50
+ output = super().from_pretrained(*args, **kwargs)
51
+ model = output[0] if isinstance(output, tuple) else output
52
+ model.model._refresh_runtime_buffers()
53
+ return output
54
+
55
  def _setup_caches(self):
56
  if not self._is_kv_cache_setup:
57
  self.model._setup_caches()
moondream.py CHANGED
@@ -22,6 +22,7 @@ from .region import (
22
  )
23
  from .layers import QuantizedLinear
24
  from .lora import load_adapter, normalize_adapter_id
 
25
  from .utils import remove_outlier_points
26
 
27
  ImageEncodingSettings = TypedDict(
@@ -171,6 +172,26 @@ class MoondreamModel(nn.Module):
171
  )
172
  return self._point_gen_indices
173
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
  def _setup_caches(self):
175
  c = self.config.text
176
  for b in self.text.blocks:
 
22
  )
23
  from .layers import QuantizedLinear
24
  from .lora import load_adapter, normalize_adapter_id
25
+ from .rope import precompute_freqs_cis
26
  from .utils import remove_outlier_points
27
 
28
  ImageEncodingSettings = TypedDict(
 
172
  )
173
  return self._point_gen_indices
174
 
175
+ def _refresh_runtime_buffers(self):
176
+ attn_mask = torch.tril(
177
+ torch.ones(
178
+ 1,
179
+ 1,
180
+ self.config.text.max_context,
181
+ self.config.text.max_context,
182
+ dtype=torch.bool,
183
+ device=self.device,
184
+ )
185
+ )
186
+ patch_w = self.config.vision.crop_size // self.config.vision.enc_patch_size
187
+ prefix_attn_len = 1 + patch_w**2
188
+ attn_mask[..., :prefix_attn_len, :prefix_attn_len] = 1
189
+ self.attn_mask = attn_mask
190
+ self.text.freqs_cis = precompute_freqs_cis(
191
+ self.config.text.dim // (2 * self.config.text.n_heads),
192
+ self.config.text.max_context,
193
+ ).to(device=self.device)
194
+
195
  def _setup_caches(self):
196
  c = self.config.text
197
  for b in self.text.blocks: