err805 committed on
Commit
fcf7123
·
1 Parent(s): a21cc17

Fix runtime buffers after load

Browse files
Files changed (2) hide show
  1. hf_moondream.py +7 -0
  2. moondream.py +24 -9
hf_moondream.py CHANGED
@@ -45,6 +45,13 @@ class HfMoondream(PreTrainedModel):
45
  self._is_kv_cache_setup = False
46
  self.post_init()
47
 
 
 
 
 
 
 
 
48
  def _setup_caches(self):
49
  if not self._is_kv_cache_setup:
50
  self.model._setup_caches()
 
45
  self._is_kv_cache_setup = False
46
  self.post_init()
47
 
48
+ @classmethod
49
+ def from_pretrained(cls, *args, **kwargs):
50
+ output = super().from_pretrained(*args, **kwargs)
51
+ model = output[0] if isinstance(output, tuple) else output
52
+ model.model._refresh_runtime_buffers()
53
+ return output
54
+
55
  def _setup_caches(self):
56
  if not self._is_kv_cache_setup:
57
  self.model._setup_caches()
moondream.py CHANGED
@@ -22,6 +22,7 @@ from .region import (
22
  )
23
  from .layers import QuantizedLinear
24
  from .lora import load_adapter, normalize_adapter_id
 
25
  from .utils import remove_outlier_points
26
 
27
  ImageEncodingSettings = TypedDict(
@@ -131,15 +132,7 @@ class MoondreamModel(nn.Module):
131
  torch.empty(config.region.size_feat_dim // 2, 2, dtype=dtype).T
132
  )
133
 
134
- attn_mask = torch.tril(
135
- torch.ones(
136
- 1, 1, config.text.max_context, config.text.max_context, dtype=torch.bool
137
- )
138
- )
139
- patch_w = config.vision.crop_size // config.vision.enc_patch_size
140
- prefix_attn_len = 1 + patch_w**2
141
- attn_mask[..., :prefix_attn_len, :prefix_attn_len] = 1
142
- self.register_buffer("attn_mask", attn_mask, persistent=False)
143
 
144
  self.use_flex_decoding = True
145
  self._causal_block_mask = None
@@ -171,6 +164,28 @@ class MoondreamModel(nn.Module):
171
  )
172
  return self._point_gen_indices
173
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
  def _setup_caches(self):
175
  c = self.config.text
176
  for b in self.text.blocks:
 
22
  )
23
  from .layers import QuantizedLinear
24
  from .lora import load_adapter, normalize_adapter_id
25
+ from .rope import precompute_freqs_cis
26
  from .utils import remove_outlier_points
27
 
28
  ImageEncodingSettings = TypedDict(
 
132
  torch.empty(config.region.size_feat_dim // 2, 2, dtype=dtype).T
133
  )
134
 
135
+ self.register_buffer("attn_mask", self._build_attn_mask(), persistent=False)
 
 
 
 
 
 
 
 
136
 
137
  self.use_flex_decoding = True
138
  self._causal_block_mask = None
 
164
  )
165
  return self._point_gen_indices
166
 
167
+ def _build_attn_mask(self):
168
+ attn_mask = torch.tril(
169
+ torch.ones(
170
+ 1,
171
+ 1,
172
+ self.config.text.max_context,
173
+ self.config.text.max_context,
174
+ dtype=torch.bool,
175
+ )
176
+ )
177
+ patch_w = self.config.vision.crop_size // self.config.vision.enc_patch_size
178
+ prefix_attn_len = 1 + patch_w**2
179
+ attn_mask[..., :prefix_attn_len, :prefix_attn_len] = 1
180
+ return attn_mask
181
+
182
+ def _refresh_runtime_buffers(self):
183
+ self.attn_mask = self._build_attn_mask().to(device=self.device)
184
+ self.text.freqs_cis = precompute_freqs_cis(
185
+ self.config.text.dim // (2 * self.config.text.n_heads),
186
+ self.config.text.max_context,
187
+ ).to(device=self.device)
188
+
189
  def _setup_caches(self):
190
  c = self.config.text
191
  for b in self.text.blocks: