| from typing import Optional |
|
|
| import transformers |
|
|
|
|
class ASRConfig(transformers.PretrainedConfig):
    """Configuration class for the ASR model.

    This config combines settings for:
    - Audio encoder (GLM-ASR/Whisper)
    - Text decoder (Qwen)
    - Projector (MLP, MOSA, MoE, QFormer)
    - Generation parameters
    - Training options (SpecAugment, LoRA)
    """

    model_type = "asr_model"
    is_composition = True

    def __init__(
        self,
        audio_model_id: str = "zai-org/GLM-ASR-Nano-2512",
        text_model_id: str = "Qwen/Qwen3-0.6B",
        attn_implementation: str = "sdpa",
        model_dtype: str = "bfloat16",
        num_beams: Optional[int] = None,
        system_prompt: str = "You are a helpful assistant.",
        encoder_dim: Optional[int] = None,
        llm_dim: Optional[int] = None,
        encoder_conv_layers: Optional[list] = None,
        audio_sample_rate: int = 16000,
        projector_pool_stride: int = 4,
        downsample_rate: int = 5,
        projector_hidden_dim: Optional[int] = None,
        projector_type: str = "mlp",
        projector_num_layers: int = 2,
        projector_init_std: float = 0.02,
        projector_dropout: float = 0.0,
        num_experts: int = 4,
        num_experts_per_tok: int = 2,
        router_aux_loss_coef: float = 0.01,
        qformer_window_size: int = 15,
        qformer_hidden_size: Optional[int] = None,
        qformer_num_layers: int = 2,
        qformer_num_heads: int = 16,
        qformer_intermediate_size: Optional[int] = None,
        label_smoothing: float = 0.0,
        inference_warmup_tokens: int = 10,
        use_specaugment: bool = False,
        num_time_masks: int = 2,
        time_mask_length: int = 10,
        num_freq_masks: int = 0,
        freq_mask_length: int = 10,
        use_lora: bool = False,
        lora_rank: int = 8,
        lora_alpha: int = 32,
        lora_dropout: float = 0.0,
        lora_target_modules: Optional[list] = None,
        freeze_projector: bool = False,
        do_sample: bool = False,
        enable_thinking: bool = False,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        min_new_tokens: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        length_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        use_cache: Optional[bool] = None,
        **kwargs,
    ):
        """Initialize ASR model configuration.

        Unless noted below, every argument is stored unchanged as an
        attribute of the same name.

        Args:
            audio_model_id: HuggingFace model ID for audio encoder (GLM-ASR/Whisper).
                Used to load ``audio_config`` when none is passed in ``kwargs``.
            text_model_id: HuggingFace model ID for text decoder (Qwen).
                Used to load ``text_config`` (with ``trust_remote_code=True``)
                when none is passed in ``kwargs``.
            attn_implementation: Attention implementation ("sdpa", "flash_attention_2", "eager")
            model_dtype: Model dtype ("bfloat16", "float16", "float32"). Also
                written to ``dtype`` of sub-configs loaded by model ID.
            projector_type: Projector architecture ("mlp", "mosa", "moe", "qformer")
            use_lora: Enable LoRA adapters for Stage 2 fine-tuning
            use_specaugment: Enable SpecAugment data augmentation
            encoder_conv_layers: Falls back to ``[(1, 3, 1), (1, 3, 2)]``
                when ``None`` or empty.
            lora_target_modules: Falls back to the attention and MLP
                projection names listed in the body when ``None`` or empty.
            num_beams, max_new_tokens, min_new_tokens, repetition_penalty,
            length_penalty, no_repeat_ngram_size, use_cache: Generation
                settings. ``None`` selects the defaults 1, 128, 0, 1.0, 1.0,
                0 and ``True`` respectively.
            do_sample, temperature, top_p, top_k: Sampling settings; the
                ``None`` defaults are stored as ``None`` before the base
                initializer runs.
            **kwargs: May contain ``audio_config`` / ``text_config`` (a config
                object or a dict with a ``model_type`` key). Everything else
                is forwarded to ``PretrainedConfig.__init__``.
        """
        generation_defaults = {
            "num_beams": 1,
            "max_new_tokens": 128,
            "min_new_tokens": 0,
            "repetition_penalty": 1.0,
            "length_penalty": 1.0,
            "no_repeat_ngram_size": 0,
            "use_cache": True,
        }
        explicit_generation = {
            "num_beams": num_beams,
            "max_new_tokens": max_new_tokens,
            "min_new_tokens": min_new_tokens,
            "repetition_penalty": repetition_penalty,
            "length_penalty": length_penalty,
            "no_repeat_ngram_size": no_repeat_ngram_size,
            "use_cache": use_cache,
        }
        # An explicitly supplied value wins; ``None`` means "use the default".
        resolved_generation = {
            name: default if explicit_generation[name] is None else explicit_generation[name]
            for name, default in generation_defaults.items()
        }

        # These settings are named parameters, so a caller's value never lands
        # in ``kwargs``. Forward the *resolved* values (not the raw defaults)
        # to the base initializer; otherwise it would assign the defaults and
        # silently discard e.g. ``max_new_tokens=256``.
        kwargs = {**resolved_generation, **kwargs}

        # Same reasoning for sampling settings: the base initializer may
        # assign generation attributes from kwargs (behaviour differs between
        # transformers releases), so explicit values are passed through.
        # Unset values are not forwarded, keeping a default call unchanged.
        if do_sample:
            kwargs["do_sample"] = do_sample
        for name, value in (("temperature", temperature), ("top_p", top_p), ("top_k", top_k)):
            if value is not None:
                kwargs[name] = value

        # Core model identification and runtime options.
        self.audio_model_id = audio_model_id
        self.text_model_id = text_model_id
        self.attn_implementation = attn_implementation
        self.model_dtype = model_dtype
        self.system_prompt = system_prompt
        self.encoder_dim = encoder_dim
        self.llm_dim = llm_dim

        # Encoder / projector geometry.
        self.encoder_conv_layers = encoder_conv_layers or [(1, 3, 1), (1, 3, 2)]
        self.audio_sample_rate = audio_sample_rate
        self.projector_init_std = projector_init_std
        self.projector_pool_stride = projector_pool_stride
        self.downsample_rate = downsample_rate
        self.projector_hidden_dim = projector_hidden_dim
        self.projector_type = projector_type
        self.projector_num_layers = projector_num_layers
        self.projector_dropout = projector_dropout

        # Mixture-of-experts projector options.
        self.num_experts = num_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.router_aux_loss_coef = router_aux_loss_coef

        # QFormer projector options.
        self.qformer_window_size = qformer_window_size
        self.qformer_hidden_size = qformer_hidden_size
        self.qformer_num_layers = qformer_num_layers
        self.qformer_num_heads = qformer_num_heads
        self.qformer_intermediate_size = qformer_intermediate_size
        self.label_smoothing = label_smoothing
        self.inference_warmup_tokens = inference_warmup_tokens

        # SpecAugment options.
        self.use_specaugment = use_specaugment
        self.num_time_masks = num_time_masks
        self.time_mask_length = time_mask_length
        self.num_freq_masks = num_freq_masks
        self.freq_mask_length = freq_mask_length

        # LoRA options.
        self.use_lora = use_lora
        self.lora_rank = lora_rank
        self.lora_alpha = lora_alpha
        self.lora_dropout = lora_dropout
        self.lora_target_modules = lora_target_modules or [
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ]
        self.freeze_projector = freeze_projector

        # Generation and sampling options.
        for name, value in resolved_generation.items():
            setattr(self, name, value)
        self.do_sample = do_sample
        self.enable_thinking = enable_thinking
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k

        # Sub-configs: take them from kwargs when supplied, otherwise load
        # them by model ID and stamp the requested dtype on them.
        if "audio_config" not in kwargs:
            self.audio_config = transformers.AutoConfig.from_pretrained(audio_model_id)
            self.audio_config.dtype = model_dtype
        else:
            self.audio_config = kwargs.pop("audio_config")

        if "text_config" not in kwargs:
            self.text_config = transformers.AutoConfig.from_pretrained(
                text_model_id, trust_remote_code=True
            )
            self.text_config.dtype = model_dtype
        else:
            self.text_config = kwargs.pop("text_config")

        # Sub-configs supplied as plain dicts are rebuilt into config
        # objects of the class registered for their ``model_type``.
        if isinstance(self.text_config, dict):
            model_type = self.text_config["model_type"]
            config_class = transformers.AutoConfig.for_model(model_type).__class__
            self.text_config = config_class(**self.text_config)

        if isinstance(self.audio_config, dict):
            model_type = self.audio_config.get("model_type")
            # A dict without a model_type is left as a dict.
            if model_type:
                config_class = transformers.AutoConfig.for_model(model_type).__class__
                self.audio_config = config_class(**self.audio_config)

        super().__init__(**kwargs)

        # Alias: the audio config is also reachable as ``encoder``.
        self.encoder = self.audio_config

        # Metadata naming the custom classes and pipeline for this model.
        self.auto_map = {
            "AutoConfig": "asr_config.ASRConfig",
            "AutoModel": "asr_modeling.ASRModel",
            "AutoModelForSpeechSeq2Seq": "asr_modeling.ASRModel",
            "AutoProcessor": "asr_processing.ASRProcessor",
        }
        self.custom_pipelines = {
            "automatic-speech-recognition": {
                "impl": "asr_pipeline.ASRPipeline",
                "pt": ["AutoModelForSpeechSeq2Seq"],
                "tf": [],
                "type": "audio",
            }
        }
        self.architectures = ["ASRModel"]
        self.pipeline_tag = "automatic-speech-recognition"
|
|
|
|
# Register ASRConfig with AutoConfig under the "asr_model" model type.
transformers.AutoConfig.register(model_type="asr_model", config=ASRConfig)
|
|