from transformers import PretrainedConfig


class CLSPConfig(PretrainedConfig):
    """Configuration container for models registered under ``model_type = "clsp"``.

    Every constructor argument is stored unchanged as an attribute of the
    same name; this class performs no parsing or validation. Several
    settings are passed as strings, some of them comma-separated lists
    (e.g. ``"1,2,4,8,4,2,1"``), and they are kept as strings here.
    NOTE(review): presumably the consuming model code splits these strings
    into per-stack integers — confirm against the model implementation.

    Any extra keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.
    """

    model_type = "clsp"

    def __init__(
        self,
        feature_dim: int = 128,
        output_downsampling_factor: int = 2,
        downsampling_factor: str = "1,2,4,8,4,2,1",
        num_encoder_layers: str = "1,2,3,4,1,1,1",
        encoder_dim: str = "1280,1280,1280,1280,1280,1280,1280",
        encoder_unmasked_dim: str = "768,768,768,768,768,768,768",
        query_head_dim: str = "32",
        pos_head_dim: str = "4",
        value_head_dim: str = "12",
        pos_dim: int = 48,
        num_heads: str = "8,8,8,8,8,8,8",
        feedforward_dim: str = "3840,3840,3840,3840,3840,3840,3840",
        cnn_module_kernel: str = "31,31,15,15,15,31,31",
        causal: bool = False,
        chunk_size: str = "-1",
        left_context_frames: str = "-1",
        text_encoder_dim: int = 768,
        joint_dim: int = 512,
        **kwargs,
    ):
        # Let the base class consume the generic config kwargs before the
        # model-specific attributes are set.
        super().__init__(**kwargs)

        # SPEAR encoder related
        # Input / output shape settings.
        self.feature_dim = feature_dim
        self.output_downsampling_factor = output_downsampling_factor

        # Encoder layout settings (string-encoded, stored as given).
        self.downsampling_factor = downsampling_factor
        self.num_encoder_layers = num_encoder_layers
        self.encoder_dim = encoder_dim
        self.encoder_unmasked_dim = encoder_unmasked_dim

        # Attention-related sizes.
        self.query_head_dim = query_head_dim
        self.pos_head_dim = pos_head_dim
        self.value_head_dim = value_head_dim
        self.pos_dim = pos_dim
        self.num_heads = num_heads

        # Feed-forward and convolution module sizes.
        self.feedforward_dim = feedforward_dim
        self.cnn_module_kernel = cnn_module_kernel

        # Causal / chunking options.
        self.causal = causal
        self.chunk_size = chunk_size
        self.left_context_frames = left_context_frames

        # NOTE(review): names suggest these size the text encoder output and
        # a shared joint space — confirm against the model code.
        self.text_encoder_dim = text_encoder_dim
        self.joint_dim = joint_dim