| from transformers import AutoConfig |
| from transformers.configuration_utils import PretrainedConfig |
| from transformers.models.dinov2.configuration_dinov2 import Dinov2Config |
|
|
|
|
class VisionConfig(PretrainedConfig):
    """Configuration for the vision encoder.

    Normally built through :meth:`from_exp_config`. All keyword arguments are
    forwarded unchanged to :class:`PretrainedConfig`.
    """

    def __init__(
        self,
        **kwargs,
    ):
        super().__init__(**kwargs)

    @staticmethod
    def from_exp_config(vision_config: dict):
        """Build a :class:`VisionConfig` from an experiment-level config dict.

        For the Hugging Face-backed model types, the config of the checkpoint
        named by ``vision_config["pretrained_name_or_path"]`` is loaded with
        ``AutoConfig.from_pretrained`` and merged on top of the experiment
        settings (checkpoint values win on key collisions). For ``"xrayclip"``
        the merged ``model_type`` is forced back to ``"xrayclip"``. For
        ``"biomedclip"`` and ``"m3ae"`` no pretrained config is loaded and the
        dict is used as-is.

        Args:
            vision_config: Experiment settings. Must contain ``"model_type"``;
                Hugging Face-backed types also need ``"pretrained_name_or_path"``.
                The caller's dict is not modified.

        Returns:
            A new :class:`VisionConfig` carrying the merged settings.

        Raises:
            KeyError: if a required key is missing.
            NotImplementedError: if ``model_type`` is not supported.
        """
        # Shallow copy so the caller's dict is left untouched.
        merged = dict(vision_config)
        model_type = merged["model_type"]

        if model_type in (
            "siglip_vision_model",
            "clip_vision_model",
            "dinov2",
            "sam",
            "raddino",
            "xrayclip",
        ):
            pretrained = AutoConfig.from_pretrained(
                merged["pretrained_name_or_path"]
            ).to_dict()
            if model_type == "xrayclip":
                # Keep the custom tag rather than the checkpoint's own model_type.
                pretrained["model_type"] = "xrayclip"
            # NOTE(review): for the other types the checkpoint's dict may
            # overwrite ``model_type`` (e.g. "raddino") — confirm downstream
            # code expects the checkpoint's value.
            merged.update(pretrained)
        elif model_type in ("biomedclip", "m3ae"):
            # No pretrained config is loaded for these types.
            pass
        else:
            raise NotImplementedError(
                f"Unsupported vision model_type: {model_type!r}"
            )

        return VisionConfig(**merged)
|
|
|
|
class TextConfig(PretrainedConfig):
    """Configuration for the text encoder.

    Args:
        model_type: Identifier of the text encoder; stored as ``self.model_type``.
        **kwargs: Forwarded unchanged to :class:`PretrainedConfig`.
    """

    def __init__(
        self,
        model_type,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.model_type = model_type

    @staticmethod
    def from_exp_config(
        text_config: dict,
    ):
        """Build a :class:`TextConfig` from an experiment-level config dict.

        Args:
            text_config: Experiment settings; must contain ``"model_type"``.
                All entries (including ``model_type``) are passed as keyword
                arguments to :class:`TextConfig`.

        Returns:
            A new :class:`TextConfig`.

        Raises:
            KeyError: if ``"model_type"`` is missing.
            NotImplementedError: if ``model_type`` is not supported.
        """
        model_type = text_config["model_type"]

        if model_type not in (
            "siglip_text_model",
            "clip_text_model",
            "mpnet",
            "biomedclip",
            "bioclinicalmpbert",
        ):
            raise NotImplementedError(
                f"Unsupported text model_type: {model_type!r}"
            )

        return TextConfig(**text_config)
|
|
|
|
class AlignTransformerConfig(PretrainedConfig):
    """Configuration for the alignment transformer.

    Args:
        model_type: Identifier stored as ``self.model_type``.
        projector_config: Optional projector settings, stored as-is.
        **kwargs: Forwarded unchanged to :class:`PretrainedConfig`.
    """

    def __init__(
        self,
        model_type: str = "align_transformer",
        projector_config=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.model_type = model_type
        self.projector_config = projector_config

    @staticmethod
    def from_exp_config(
        align_transformer_config: dict,
    ):
        """Build an :class:`AlignTransformerConfig` from an experiment dict.

        The settings (minus ``"projector_config"``) are first expanded through
        ``Dinov2Config`` so its defaults fill any unspecified fields; the
        explicitly provided settings then take precedence on key collisions.

        Args:
            align_transformer_config: Experiment settings. An optional
                ``"projector_config"`` entry is split out and passed separately.
                The caller's dict is not modified.

        Returns:
            A new :class:`AlignTransformerConfig`.
        """
        # Pop from a shallow copy: popping the caller's dict would drop
        # ``projector_config`` if the same dict is used to build a config again.
        remaining = dict(align_transformer_config)
        projector_config = remaining.pop("projector_config", None)

        dinov2_defaults = Dinov2Config(**remaining).to_dict()

        # NOTE(review): Dinov2Config.to_dict() presumably includes
        # model_type="dinov2", which overrides the "align_transformer" default
        # unless the caller supplies model_type — confirm this is intended.
        return AlignTransformerConfig(
            **(dinov2_defaults | remaining),
            projector_config=projector_config,
        )
|
|
|
|
class CxrAlignConfig(PretrainedConfig):
    """Top-level config composing the vision, text and alignment sub-configs.

    Each sub-config argument is an experiment-level dict that is converted into
    its typed config object through the matching ``from_exp_config`` factory.
    Any remaining keyword arguments are forwarded to :class:`PretrainedConfig`.
    """

    # Class flag read by transformers' PretrainedConfig for composite configs.
    is_composition = True

    def __init__(
        self,
        vision_config: dict,
        text_config: dict,
        align_transformer_config: dict,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Resolve the three sub-configs in order: vision, text, alignment.
        vision = VisionConfig.from_exp_config(vision_config)
        text = TextConfig.from_exp_config(text_config)
        aligner = AlignTransformerConfig.from_exp_config(align_transformer_config)

        self.vision_config = vision
        self.text_config = text
        self.align_transformer_config = aligner

        # NOTE(review): these kwargs were already forwarded to PretrainedConfig
        # above; this extra copy is serialized as a "kwargs" entry too — confirm
        # something actually reads ``self.kwargs``.
        self.kwargs = kwargs
|
|