from dataclasses import asdict, dataclass, field
from typing import Dict, List, Optional
|
|
|
|
|
@dataclass(frozen=True)
class TextConfig:
    """Default hyperparameters for the text model.

    The dataclass is frozen, so an instance cannot be modified after
    construction; derive variants with ``dataclasses.replace`` instead.
    """

    # Size of the model. The names suggest hidden width, feed-forward width
    # and layer count -- confirm against the model code that reads them.
    dim: int = 2048
    ff_dim: int = 8192
    n_layers: int = 24

    # Vocabulary and context limits.
    vocab_size: int = 51200
    max_context: int = 2048

    # Attention head counts. They are equal by default; presumably
    # n_kv_heads < n_heads would enable grouped key/value heads -- TODO confirm.
    n_heads: int = 32
    n_kv_heads: int = 32

    # NOTE(review): the meaning of prefix_attn (730) is not visible in this file.
    prefix_attn: int = 730

    # Optional; None unless overridden. Presumably a quantization group size -- verify.
    group_size: Optional[int] = None
|
|
|
|
|
@dataclass(frozen=True)
class VisionConfig:
    """Default hyperparameters for the vision encoder and its projection.

    Frozen: build modified copies with ``dataclasses.replace``.
    """

    # Encoder shape. The enc_* names suggest width, patch size, depth,
    # feed-forward width and head count -- confirm against the encoder code.
    enc_dim: int = 1152
    enc_patch_size: int = 14
    enc_n_layers: int = 27
    enc_ff_dim: int = 4304
    enc_n_heads: int = 16

    # Output width of the projection (presumably into the text model's width).
    proj_out_dim: int = 2048

    # Image cropping / input settings.
    crop_size: int = 378
    in_channels: int = 3
    max_crops: int = 12
    # NOTE(review): units of overlap_margin (patches vs. pixels) are not
    # visible here -- confirm at the call site.
    overlap_margin: int = 4

    # Hidden width inside the projection.
    proj_inner_dim: int = 8192
|
|
|
|
|
@dataclass(frozen=True)
class RegionConfig:
    """Default hyperparameters for the region (coordinate/size) module.

    Frozen: build modified copies with ``dataclasses.replace``.
    """

    dim: int = 2048

    # Coordinate branch widths (feature size in, projected size out).
    coord_feat_dim: int = 256
    coord_out_dim: int = 1024

    # Size branch widths (feature size in, projected size out).
    size_feat_dim: int = 512
    size_out_dim: int = 2048

    inner_dim: int = 8192

    # Optional; None unless overridden. Presumably a quantization group size -- verify.
    group_size: Optional[int] = None
|
|
|
|
|
def _default_prompt_templates() -> Dict[str, Optional[Dict[str, List[int]]]]:
    """Build a fresh copy of the built-in prompt token templates.

    Used as ``default_factory`` so each ``TokenizerConfig`` gets its own
    nested dicts and lists rather than sharing one object.
    """
    caption_variants = {
        "short": [1, 32708, 2, 12492, 3],
        "normal": [1, 32708, 2, 6382, 3],
        "long": [1, 32708, 2, 4059, 3],
    }
    return {
        "caption": caption_variants,
        "query": {"prefix": [1, 15381, 2], "suffix": [3]},
        "detect": {"prefix": [1, 7235, 476, 2], "suffix": [3]},
        "point": {"prefix": [1, 2581, 2], "suffix": [3]},
    }


@dataclass(frozen=True)
class TokenizerConfig:
    """Special token ids and per-task prompt token templates.

    NOTE: although frozen, instances are not hashable in practice, because
    ``hash()`` would have to hash the ``templates`` dict (TypeError).
    """

    # Special token ids; bos_id and eos_id share the same default value.
    bos_id: int = 0
    eos_id: int = 0
    answer_id: int = 3
    thinking_id: int = 4
    coord_id: int = 5
    size_id: int = 6
    start_ground_points_id: int = 7
    end_ground_id: int = 9

    # Task name -> template. "caption" maps a length name to a complete token
    # sequence; the other tasks provide "prefix"/"suffix" token lists
    # (presumably placed around the user's text -- confirm at call sites).
    templates: Dict[str, Optional[Dict[str, List[int]]]] = field(
        default_factory=_default_prompt_templates
    )
|
|
|
|
|
@dataclass(frozen=True)
class MoondreamConfig:
    """Top-level configuration grouping the text, vision, region and tokenizer configs.

    Each field's default is a single instance created once when this class is
    defined, so every ``MoondreamConfig()`` built with defaults shares those
    sub-config objects. The sub-configs are frozen, but
    ``TokenizerConfig.templates`` is a plain dict. That dict is therefore
    shared, and mutable, across all such instances.
    """

    text: TextConfig = TextConfig()
    vision: VisionConfig = VisionConfig()
    region: RegionConfig = RegionConfig()
    tokenizer: TokenizerConfig = TokenizerConfig()

    @classmethod
    def from_dict(cls, config_dict: dict) -> "MoondreamConfig":
        """Build a config from a nested mapping (e.g. a parsed JSON config).

        Args:
            config_dict: Mapping with optional ``"text"``, ``"vision"``,
                ``"region"`` and ``"tokenizer"`` sections. Each section is a
                mapping of field overrides for the matching sub-config. A
                missing section, or one explicitly set to ``None``, uses that
                sub-config's defaults.

        Returns:
            A new ``MoondreamConfig``.

        Raises:
            TypeError: If a section contains a key that is not a field of its
                sub-config, or is neither ``None`` nor a mapping.
        """

        def section(name: str) -> dict:
            # An explicit null section (``"text": null`` in JSON) means "use the
            # defaults" instead of failing with an opaque ``**None`` TypeError.
            value = config_dict.get(name)
            return {} if value is None else value

        return cls(
            text=TextConfig(**section("text")),
            vision=VisionConfig(**section("vision")),
            region=RegionConfig(**section("region")),
            tokenizer=TokenizerConfig(**section("tokenizer")),
        )

    def to_dict(self) -> dict:
        """Return this config as a nested plain-dict copy.

        The layout matches what ``from_dict`` accepts, so
        ``MoondreamConfig.from_dict(cfg.to_dict()) == cfg``. The result is built
        with ``dataclasses.asdict``, which recursively copies nested dicts and
        lists. Mutating the returned mapping therefore cannot alter this
        frozen config; returning each sub-config's live ``__dict__`` would
        allow that.
        """
        return asdict(self)
|
|
|