"""Student causal LM wrapped with a cross-attention adapter, exposed as a text encoder.

The student's hidden state at ``config.hs_tap_index`` is passed through
:class:`Adapter`, which maps it to ``config.teacher_hidden_size``.
"""

import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPast

from .configuration_student_adapter import StudentAdapterConfig


class XAttnBlock(nn.Module):
    """Pre-norm attention block: attention residual followed by a feed-forward residual.

    Queries and keys/values get separate LayerNorms before
    ``nn.MultiheadAttention`` (``batch_first=True``), so tensors are laid out
    as ``(batch, seq, dim)``.

    Args:
        dim: Width of queries, keys and values.
        heads: Number of attention heads (``nn.MultiheadAttention`` requires
            ``dim`` to be divisible by ``heads``).
        ff_mult: Hidden-width multiplier of the feed-forward sub-layer.
        dropout: Dropout probability for attention and the feed-forward sub-layer.
    """

    def __init__(self, dim, heads, ff_mult=4, dropout=0.1):
        super().__init__()
        self.norm_q = nn.LayerNorm(dim)
        self.norm_kv = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, dropout=dropout, batch_first=True)
        self.norm_ff = nn.LayerNorm(dim)
        self.ff = nn.Sequential(
            nn.Linear(dim, dim * ff_mult),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * ff_mult, dim),
            nn.Dropout(dropout),
        )

    def forward(self, q, kv, key_padding_mask=None) -> torch.Tensor:
        """Update ``q`` by attending over ``kv``, then apply the feed-forward layer.

        Args:
            q: Query tensor ``(batch, q_len, dim)``.
            kv: Key/value tensor ``(batch, kv_len, dim)``.
            key_padding_mask: Optional bool tensor ``(batch, kv_len)``; ``True``
                marks key positions that must be ignored.

        Returns:
            Tensor with the same shape as ``q``.
        """
        # Normalise keys/values once and use the result as both key and value.
        kv_normed = self.norm_kv(kv)
        attn_out, _ = self.attn(
            self.norm_q(q),
            kv_normed,
            kv_normed,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        q = q + attn_out
        q = q + self.ff(self.norm_ff(q))
        return q


class Adapter(nn.Module):
    """Map student hidden states to the teacher hidden size via stacked attention blocks.

    Queries and keys/values are two separate linear projections of the same
    student hidden states. Padded positions are zeroed in the output.

    Args:
        s_dim: Student hidden size (input width).
        t_dim: Teacher hidden size (output width).
        dim: Internal width of the attention blocks.
        heads: Attention heads per block.
        blocks: Number of :class:`XAttnBlock` layers.
        ff_mult: Feed-forward width multiplier per block.
        dropout: Dropout probability per block.
    """

    def __init__(self, s_dim, t_dim, dim=1024, heads=8, blocks=2, ff_mult=4, dropout=0.1):
        super().__init__()
        self.q_proj = nn.Linear(s_dim, dim)
        self.kv_proj = nn.Linear(s_dim, dim)
        self.blocks = nn.ModuleList([
            XAttnBlock(dim, heads, ff_mult=ff_mult, dropout=dropout)
            for _ in range(blocks)
        ])
        self.proj_out = nn.Linear(dim, t_dim)

    def forward(self, student_hs, mask) -> torch.Tensor:
        """Adapt ``student_hs`` to the teacher width.

        Args:
            student_hs: Student hidden states ``(batch, seq, s_dim)``.
            mask: ``(batch, seq)`` attention mask; nonzero/True marks a real token.

        Returns:
            ``(batch, seq, t_dim)`` tensor with padded positions set to 0.
        """
        pad = ~mask.bool()
        # A row whose keys are all masked can make nn.MultiheadAttention return
        # NaN, which then leaks into weight gradients. Those rows are zeroed at
        # the end anyway, so let them attend over every key to stay finite.
        empty_rows = pad.all(dim=-1, keepdim=True)
        key_padding_mask = pad & ~empty_rows

        q = self.q_proj(student_hs)
        kv = self.kv_proj(student_hs)
        for block in self.blocks:
            q = block(q, kv, key_padding_mask=key_padding_mask)
        out = self.proj_out(q)
        return out.masked_fill(pad.unsqueeze(-1), 0)


class StudentAdapterTextEncoder(PreTrainedModel):
    """Text encoder = student causal LM + :class:`Adapter` on one of its hidden states.

    The student is built from ``config.student_config_dict`` with freshly
    initialised weights. Its hidden state at ``config.hs_tap_index`` is mapped
    to ``config.teacher_hidden_size`` by the adapter.
    """

    config_class = StudentAdapterConfig
    base_model_prefix = "student"

    def __init__(self, config: StudentAdapterConfig):
        """Build the student model and adapter from ``config``.

        Raises:
            ValueError: if ``config.student_config_dict`` is empty/None, or no
                student ``model_type`` can be determined.
        """
        super().__init__(config)
        student_cfg_dict = dict(config.student_config_dict or {})
        if not student_cfg_dict:
            raise ValueError("StudentAdapterConfig.student_config_dict is required")

        # The dict's own model_type wins; fall back to the top-level config field.
        model_type = student_cfg_dict.get("model_type") or config.student_model_type
        if model_type is None:
            raise ValueError("Missing student model_type")

        cfg_kwargs = dict(student_cfg_dict)
        # model_type is passed positionally to for_model, so drop it from kwargs.
        cfg_kwargs.pop("model_type", None)
        student_cfg = AutoConfig.for_model(model_type, **cfg_kwargs)
        self.student = AutoModelForCausalLM.from_config(student_cfg, trust_remote_code=True)

        # Prefer the instantiated student's hidden_size; fall back to the config value.
        s_dim = int(getattr(self.student.config, "hidden_size", config.student_hidden_size))
        t_dim = int(config.teacher_hidden_size)
        self.adapter = Adapter(
            s_dim=s_dim,
            t_dim=t_dim,
            dim=config.adapter_dim,
            heads=config.adapter_heads,
            blocks=config.adapter_blocks,
            ff_mult=config.adapter_ff_mult,
            dropout=config.adapter_dropout,
        )
        self.hs_tap_index = int(config.hs_tap_index)
        self.post_init()

    def _extract_hs(self, outputs, idx: int):
        """Return ``outputs.hidden_states[idx]`` (negative indices allowed).

        Raises:
            RuntimeError: if ``outputs.hidden_states`` is None.
            IndexError: if ``idx`` is out of range.
        """
        hs = outputs.hidden_states
        if hs is None:
            raise RuntimeError("Student output_hidden_states is required")
        if not (-len(hs) <= idx < len(hs)):
            raise IndexError(f"hidden-state index {idx} out of range for len={len(hs)}")
        return hs[idx]

    def forward(self, input_ids=None, attention_mask=None, output_hidden_states=True, return_dict=True, **kwargs):
        """Encode ``input_ids`` with the student and adapt the tapped hidden state.

        Args:
            input_ids: Token ids (required).
            attention_mask: Optional mask; defaults to all ones. Bool masks are
                cast to long.
            output_hidden_states: Ignored; the student is always asked for
                hidden states because the adapter consumes one.
            return_dict: If False, return a 4-tuple instead of a model output.
            **kwargs: Forwarded unchanged to the student model.

        Returns:
            ``BaseModelOutputWithPast`` with ``last_hidden_state`` = adapter output
            and ``hidden_states`` = the student's hidden states with index ``-2``
            replaced by the adapter output (appended when there are fewer than
            two). With ``return_dict=False``:
            ``(adapted, None, hidden_states, None)``.

        Raises:
            ValueError: if ``input_ids`` is None.
            RuntimeError: if the student returns no hidden states.
            IndexError: if ``hs_tap_index`` is out of range.
        """
        if input_ids is None:
            raise ValueError("input_ids is required")
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.long)
        # Qwen3 student model expects long dtype; pipeline may pass bool masks
        if attention_mask.dtype == torch.bool:
            attention_mask = attention_mask.long()

        out = self.student(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
            **kwargs,
        )
        # Validate before materialising the tuple, so a missing hidden_states
        # raises the descriptive RuntimeError instead of a TypeError from list(None).
        s_hs = self._extract_hs(out, self.hs_tap_index)
        hs_list = list(out.hidden_states)

        # Run the adapter in its own parameter dtype (student may use another).
        ad_dtype = next(self.adapter.parameters()).dtype
        if s_hs.dtype != ad_dtype:
            s_hs = s_hs.to(ad_dtype)
        adapted = self.adapter(s_hs, attention_mask)

        # NOTE(review): the adapter output overwrites hidden_states[-2], presumably
        # because a downstream consumer reads the penultimate hidden state —
        # confirm with the caller before changing.
        if len(hs_list) >= 2:
            hs_list[-2] = adapted
        else:
            hs_list.append(adapted)
        hidden_states = tuple(hs_list)

        if not return_dict:
            return (adapted, None, hidden_states, None)
        return BaseModelOutputWithPast(
            last_hidden_state=adapted,
            past_key_values=None,
            hidden_states=hidden_states,
            attentions=None,
        )