from transformers import PretrainedConfig


class StudentAdapterConfig(PretrainedConfig):
    """Configuration for a student model paired with an adapter.

    Holds the student model's own configuration as a plain ``dict`` plus
    the hyper-parameters of the adapter. Numeric fields are coerced to
    ``int``/``float`` so values coming from JSON, CLI strings, etc. end up
    with consistent Python types.

    Args:
        student_config_dict: The student model's configuration. Either a
            mapping or a ``PretrainedConfig``. It is always stored as a
            shallow-copied plain ``dict``, so the attribute has the same type
            before and after a save/load round-trip and is not aliased to
            the caller's object. A falsy value (e.g. ``None``) becomes ``{}``.
        student_model_type: Identifier of the student model type, stored
            as-is.
        hs_tap_index: Hidden-state index to tap. Presumably this indexes the
            student's hidden states, with negative values counting from the
            end; TODO confirm against the model code. Coerced to ``int``.
        adapter_dim: Adapter width. Coerced to ``int``.
        adapter_heads: Number of adapter attention heads (assumed from the
            name). Coerced to ``int``.
        adapter_blocks: Number of adapter blocks. Coerced to ``int``.
        adapter_ff_mult: Feed-forward expansion multiplier (assumed from the
            name). Coerced to ``int``.
        adapter_dropout: Dropout probability used by the adapter. Coerced to
            ``float``.
        teacher_hidden_size: Teacher hidden size, or ``None`` if unknown.
            Coerced to ``int`` when given.
        student_hidden_size: Student hidden size, or ``None`` if unknown.
            Coerced to ``int`` when given.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "zimage_student_adapter"

    def __init__(
        self,
        student_config_dict=None,
        student_model_type=None,
        hs_tap_index=-2,
        adapter_dim=1024,
        adapter_heads=8,
        adapter_blocks=2,
        adapter_ff_mult=4,
        adapter_dropout=0.1,
        teacher_hidden_size=None,
        student_hidden_size=None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        if isinstance(student_config_dict, PretrainedConfig):
            # Normalize a nested config object to a plain dict. After
            # from_pretrained the attribute would be a dict anyway, so this
            # keeps its type stable across save/load.
            student_config_dict = student_config_dict.to_dict()
        # Copy so later mutation of the caller's dict cannot alter this config.
        self.student_config_dict = (
            dict(student_config_dict) if student_config_dict else {}
        )
        self.student_model_type = student_model_type

        self.hs_tap_index = int(hs_tap_index)
        self.adapter_dim = int(adapter_dim)
        self.adapter_heads = int(adapter_heads)
        self.adapter_blocks = int(adapter_blocks)
        self.adapter_ff_mult = int(adapter_ff_mult)
        self.adapter_dropout = float(adapter_dropout)

        # Optional sizes: keep None meaning "unknown", otherwise coerce to int
        # like the other numeric fields.
        self.teacher_hidden_size = (
            None if teacher_hidden_size is None else int(teacher_hidden_size)
        )
        self.student_hidden_size = (
            None if student_hidden_size is None else int(student_hidden_size)
        )