| from transformers.configuration_utils import PretrainedConfig |
| import sys |
| from transformers import ( |
| AutoConfig, |
| AutoModelForCausalLM, |
| LlamaConfig, |
| LlamaForCausalLM, |
| PreTrainedModel, |
| ) |
| from .attrdict_config import AttrDict |
|
|
class VisionConfig(PretrainedConfig):
    """Sub-configuration registered under ``model_type`` ``"vision"``.

    Stores a component class name in ``cls`` and an ``AttrDict`` of
    settings in ``params`` (presumably the keyword arguments the model code
    passes to that class -- confirm against the caller).
    """

    model_type = "vision"
    cls: str = ""
    # Class-level fallback only: ``__init__`` always assigns a fresh
    # per-instance ``AttrDict``, so instances never share this dict.
    params: AttrDict = {}

    def __init__(self, **kwargs):
        """Build the config from keyword arguments.

        Args:
            **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
                Two keys are also interpreted here:
                ``cls`` -- a class-name string or a class object (stored as
                its ``__name__``); missing or ``None`` becomes ``""``.
                ``params`` -- a mapping wrapped in ``AttrDict``; missing or
                ``None`` becomes an empty mapping.
        """
        super().__init__(**kwargs)

        cls = kwargs.get("cls")
        if cls is None:
            # An explicit None would otherwise hit ``None.__name__``.
            cls = ""
        elif not isinstance(cls, str):
            # Keep only the name (presumably so the config serializes).
            cls = cls.__name__
        self.cls = cls

        params = kwargs.get("params")
        self.params = AttrDict({} if params is None else params)
|
|
|
|
class AlignerConfig(PretrainedConfig):
    """Sub-configuration registered under ``model_type`` ``"aligner"``.

    Stores a component class name in ``cls`` and an ``AttrDict`` of
    settings in ``params`` (presumably the keyword arguments the model code
    passes to that class -- confirm against the caller).
    """

    model_type = "aligner"
    cls: str = ""
    # Class-level fallback only: ``__init__`` always assigns a fresh
    # per-instance ``AttrDict``, so instances never share this dict.
    params: AttrDict = {}

    def __init__(self, **kwargs):
        """Build the config from keyword arguments.

        Args:
            **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
                Two keys are also interpreted here:
                ``cls`` -- a class-name string or a class object (stored as
                its ``__name__``); missing or ``None`` becomes ``""``.
                ``params`` -- a mapping wrapped in ``AttrDict``; missing or
                ``None`` becomes an empty mapping.
        """
        super().__init__(**kwargs)

        cls = kwargs.get("cls")
        if cls is None:
            # An explicit None would otherwise hit ``None.__name__``.
            cls = ""
        elif not isinstance(cls, str):
            # Keep only the name (presumably so the config serializes).
            cls = cls.__name__
        self.cls = cls

        params = kwargs.get("params")
        self.params = AttrDict({} if params is None else params)
|
|
|
|
class GenVisionConfig(PretrainedConfig):
    """Sub-configuration registered under ``model_type`` ``"gen_vision"``.

    Stores a component class name in ``cls`` and an ``AttrDict`` of
    settings in ``params`` (presumably the keyword arguments the model code
    passes to that class -- confirm against the caller).
    """

    model_type = "gen_vision"
    cls: str = ""
    # Class-level fallback only: ``__init__`` always assigns a fresh
    # per-instance ``AttrDict``, so instances never share this dict.
    params: AttrDict = {}

    def __init__(self, **kwargs):
        """Build the config from keyword arguments.

        Args:
            **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
                Two keys are also interpreted here:
                ``cls`` -- a class-name string or a class object (stored as
                its ``__name__``); missing or ``None`` becomes ``""``.
                ``params`` -- a mapping wrapped in ``AttrDict``; missing or
                ``None`` becomes an empty mapping.
        """
        super().__init__(**kwargs)

        cls = kwargs.get("cls")
        if cls is None:
            # An explicit None would otherwise hit ``None.__name__``.
            cls = ""
        elif not isinstance(cls, str):
            # Keep only the name (presumably so the config serializes).
            cls = cls.__name__
        self.cls = cls

        params = kwargs.get("params")
        self.params = AttrDict({} if params is None else params)
|
|
|
|
class GenAlignerConfig(PretrainedConfig):
    """Sub-configuration registered under ``model_type`` ``"gen_aligner"``.

    Stores a component class name in ``cls`` and an ``AttrDict`` of
    settings in ``params`` (presumably the keyword arguments the model code
    passes to that class -- confirm against the caller).
    """

    model_type = "gen_aligner"
    cls: str = ""
    # Class-level fallback only: ``__init__`` always assigns a fresh
    # per-instance ``AttrDict``, so instances never share this dict.
    params: AttrDict = {}

    def __init__(self, **kwargs):
        """Build the config from keyword arguments.

        Args:
            **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
                Two keys are also interpreted here:
                ``cls`` -- a class-name string or a class object (stored as
                its ``__name__``); missing or ``None`` becomes ``""``.
                ``params`` -- a mapping wrapped in ``AttrDict``; missing or
                ``None`` becomes an empty mapping.
        """
        super().__init__(**kwargs)

        cls = kwargs.get("cls")
        if cls is None:
            # An explicit None would otherwise hit ``None.__name__``.
            cls = ""
        elif not isinstance(cls, str):
            # Keep only the name (presumably so the config serializes).
            cls = cls.__name__
        self.cls = cls

        params = kwargs.get("params")
        self.params = AttrDict({} if params is None else params)
|
|
|
|
class GenHeadConfig(PretrainedConfig):
    """Sub-configuration registered under ``model_type`` ``"gen_head"``.

    Stores a component class name in ``cls`` and an ``AttrDict`` of
    settings in ``params`` (presumably the keyword arguments the model code
    passes to that class -- confirm against the caller).
    """

    model_type = "gen_head"
    cls: str = ""
    # Class-level fallback only: ``__init__`` always assigns a fresh
    # per-instance ``AttrDict``, so instances never share this dict.
    params: AttrDict = {}

    def __init__(self, **kwargs):
        """Build the config from keyword arguments.

        Args:
            **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
                Two keys are also interpreted here:
                ``cls`` -- a class-name string or a class object (stored as
                its ``__name__``); missing or ``None`` becomes ``""``.
                ``params`` -- a mapping wrapped in ``AttrDict``; missing or
                ``None`` becomes an empty mapping.
        """
        super().__init__(**kwargs)

        cls = kwargs.get("cls")
        if cls is None:
            # An explicit None would otherwise hit ``None.__name__``.
            cls = ""
        elif not isinstance(cls, str):
            # Keep only the name (presumably so the config serializes).
            cls = cls.__name__
        self.cls = cls

        params = kwargs.get("params")
        self.params = AttrDict({} if params is None else params)
|
|
|
|
class MultiModalityConfig(PretrainedConfig):
    """Composite configuration holding six sub-configurations.

    Each of ``vision_config``, ``aligner_config``, ``gen_vision_config``,
    ``gen_aligner_config``, ``gen_head_config`` and ``language_config`` may
    be given as a plain mapping, as an already-built instance of the matching
    config class, or omitted / ``None`` (which builds a default instance).
    """

    model_type = "multi_modality"
    vision_config: VisionConfig
    aligner_config: AlignerConfig

    gen_vision_config: GenVisionConfig
    gen_aligner_config: GenAlignerConfig
    gen_head_config: GenHeadConfig

    language_config: LlamaConfig

    def __init__(self, **kwargs):
        """Build the composite config.

        Args:
            **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
                The six ``*_config`` keys are then normalized into config
                objects (see the class docstring), overwriting whatever raw
                values the base initializer stored.

        Raises:
            TypeError: if a ``*_config`` value is neither ``None``, a mapping,
                nor an instance of the expected config class.
        """
        super().__init__(**kwargs)
        coerce = self._coerce_sub_config

        self.vision_config = coerce(kwargs.get("vision_config"), VisionConfig)
        self.aligner_config = coerce(kwargs.get("aligner_config"), AlignerConfig)

        self.gen_vision_config = coerce(
            kwargs.get("gen_vision_config"), GenVisionConfig
        )
        self.gen_aligner_config = coerce(
            kwargs.get("gen_aligner_config"), GenAlignerConfig
        )
        self.gen_head_config = coerce(kwargs.get("gen_head_config"), GenHeadConfig)

        self.language_config = coerce(kwargs.get("language_config"), LlamaConfig)

    @staticmethod
    def _coerce_sub_config(value, config_cls):
        """Return ``value`` as an instance of ``config_cls``.

        Args:
            value: ``None``, an instance of ``config_cls``, or a mapping of
                keyword arguments for ``config_cls``.
            config_cls: The config class to build.

        Returns:
            ``config_cls()`` for ``None``, ``value`` itself if it is already
            a ``config_cls`` instance, otherwise ``config_cls(**value)``.
        """
        if value is None:
            return config_cls()
        if isinstance(value, config_cls):
            # Reuse the caller's object without copying.
            return value
        return config_cls(**value)
|
|