SixOpen committed on
Commit
7d6b779
·
verified ·
1 Parent(s): ed1528a

Update modeling_hare.py

Browse files
Files changed (1) hide show
  1. modeling_hare.py +14 -4
modeling_hare.py CHANGED
@@ -1,9 +1,13 @@
1
  import json
 
2
  from pathlib import Path
3
 
4
  import torch
5
- from transformers import AutoModel, AutoConfig, PreTrainedModel
6
- from transformers.modeling_outputs import BaseModelOutput
 
 
 
7
 
8
  from .configuration_hare import HareConfig
9
  from .birwkv7 import BiRWKV7Layer, init_from_attention
@@ -46,14 +50,20 @@ class HareModel(PreTrainedModel):
46
 
47
  def __init__(self, config):
48
  super().__init__(config)
49
- base_config = AutoConfig.from_pretrained(
50
- "answerdotai/ModernBERT-base",
51
  hidden_size=config.hidden_size,
52
  num_attention_heads=config.num_attention_heads,
53
  num_hidden_layers=config.num_hidden_layers,
54
  intermediate_size=config.intermediate_size,
55
  vocab_size=config.vocab_size,
56
  max_position_embeddings=config.max_position_embeddings,
 
 
 
 
 
 
 
57
  )
58
  self.inner_model = AutoModel.from_config(base_config)
59
 
 
1
  import json
2
+ import logging
3
  from pathlib import Path
4
 
5
  import torch
6
+ from transformers import AutoModel, PreTrainedModel
7
+ from transformers import ModernBertConfig
8
+
9
+ for _logger_name in ["transformers.modeling_utils", "transformers.configuration_utils"]:
10
+ logging.getLogger(_logger_name).setLevel(logging.ERROR)
11
 
12
  from .configuration_hare import HareConfig
13
  from .birwkv7 import BiRWKV7Layer, init_from_attention
 
50
 
51
  def __init__(self, config):
52
  super().__init__(config)
53
+ base_config = ModernBertConfig(
 
54
  hidden_size=config.hidden_size,
55
  num_attention_heads=config.num_attention_heads,
56
  num_hidden_layers=config.num_hidden_layers,
57
  intermediate_size=config.intermediate_size,
58
  vocab_size=config.vocab_size,
59
  max_position_embeddings=config.max_position_embeddings,
60
+ pad_token_id=config.pad_token_id,
61
+ bos_token_id=config.bos_token_id,
62
+ eos_token_id=config.eos_token_id,
63
+ cls_token_id=getattr(config, 'cls_token_id', config.bos_token_id),
64
+ sep_token_id=getattr(config, 'sep_token_id', config.eos_token_id),
65
+ global_attn_every_n_layers=getattr(config, 'global_attn_every_n_layers', 3),
66
+ local_attention=getattr(config, 'local_attention', 128),
67
  )
68
  self.inner_model = AutoModel.from_config(base_config)
69