# Z-Image-Turbo-student-adapter / text_encoder / modeling_student_adapter.py
# Uploaded by SearchingMan — commit 0179f45 (verified):
#   "Z-Image-Turbo with student+adapter text encoder"
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPast
from .configuration_student_adapter import StudentAdapterConfig
class XAttnBlock(nn.Module):
    """Pre-norm cross-attention sub-layer followed by a pre-norm feed-forward MLP.

    ``q`` attends to ``kv`` through :class:`torch.nn.MultiheadAttention`
    (``batch_first=True``). Both sub-layers are wrapped in residual connections.

    Args:
        dim: Width shared by ``q`` and ``kv`` (``MultiheadAttention`` requires
            it to be divisible by ``heads``).
        heads: Number of attention heads.
        ff_mult: Hidden-width multiplier of the feed-forward MLP.
        dropout: Dropout probability used inside attention and the MLP.
    """

    def __init__(self, dim, heads, ff_mult=4, dropout=0.1):
        super().__init__()
        self.norm_q = nn.LayerNorm(dim)
        self.norm_kv = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, dropout=dropout, batch_first=True)
        self.norm_ff = nn.LayerNorm(dim)
        # The (Linear, GELU, Dropout, Linear, Dropout) layout fixes the state-dict
        # keys ``ff.0.*`` / ``ff.3.*``; keep it stable so saved checkpoints load.
        self.ff = nn.Sequential(
            nn.Linear(dim, dim * ff_mult),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * ff_mult, dim),
            nn.Dropout(dropout),
        )

    def forward(self, q, kv, key_padding_mask=None):
        """Attend from ``q`` to ``kv``, then apply the feed-forward sub-layer.

        Args:
            q: Query sequence; assumed ``(batch, q_len, dim)`` given ``batch_first=True``.
            kv: Key/value sequence; assumed ``(batch, kv_len, dim)``.
            key_padding_mask: Optional mask forwarded to ``MultiheadAttention``;
                for a bool mask, ``True`` positions of ``kv`` are ignored.

        Returns:
            Tensor with the same shape as ``q``.
        """
        # Keys and values are the same normalized tensor, so run the LayerNorm
        # once instead of twice per call.
        kv_normed = self.norm_kv(kv)
        attn_out, _ = self.attn(
            self.norm_q(q),
            kv_normed,
            kv_normed,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        q = q + attn_out
        q = q + self.ff(self.norm_ff(q))
        return q
class Adapter(nn.Module):
    """Map student hidden states to the teacher width with stacked cross-attention.

    The student sequence is projected twice into width ``dim``: once as the
    running query stream and once as a fixed key/value memory. The query stream
    passes through ``blocks`` :class:`XAttnBlock` layers, is projected to
    ``t_dim``, and is zeroed at padded positions.

    Args:
        s_dim: Width of the incoming student hidden states.
        t_dim: Width of the produced output.
        dim: Internal adapter width.
        heads: Attention heads per block.
        blocks: Number of cross-attention blocks.
        ff_mult: Feed-forward width multiplier inside each block.
        dropout: Dropout probability inside each block.
    """

    def __init__(self, s_dim, t_dim, dim=1024, heads=8, blocks=2, ff_mult=4, dropout=0.1):
        super().__init__()
        # Registration order (q_proj, kv_proj, blocks, proj_out) defines the
        # state-dict layout; keep it unchanged.
        self.q_proj = nn.Linear(s_dim, dim)
        self.kv_proj = nn.Linear(s_dim, dim)
        self.blocks = nn.ModuleList(
            XAttnBlock(dim, heads, ff_mult=ff_mult, dropout=dropout) for _ in range(blocks)
        )
        self.proj_out = nn.Linear(dim, t_dim)

    def forward(self, student_hs, mask):
        """Adapt ``student_hs`` and zero out padded positions.

        Args:
            student_hs: Student hidden states; assumed ``(batch, seq, s_dim)``.
            mask: Token mask where nonzero / ``True`` marks a real token;
                assumed ``(batch, seq)``.

        Returns:
            Tensor whose last dimension is ``t_dim``; rows where ``mask`` is
            zero/False are exactly 0.
        """
        hidden = self.q_proj(student_hs)
        memory = self.kv_proj(student_hs)
        # True where the token is padding: the convention MultiheadAttention
        # uses for key_padding_mask, and the positions to blank at the end.
        padding = ~mask.bool()
        for layer in self.blocks:
            hidden = layer(hidden, memory, key_padding_mask=padding)
        return self.proj_out(hidden).masked_fill(padding.unsqueeze(-1), 0)
class StudentAdapterTextEncoder(PreTrainedModel):
    """Text encoder built from a causal-LM "student" plus a cross-attention :class:`Adapter`.

    The student architecture is created from ``config.student_config_dict``. On
    each forward pass, the student hidden state at ``config.hs_tap_index`` is
    mapped by the adapter to width ``config.teacher_hidden_size``.
    """

    config_class = StudentAdapterConfig
    base_model_prefix = "student"

    def __init__(self, config: StudentAdapterConfig):
        super().__init__(config)
        student_cfg_dict = dict(config.student_config_dict or {})
        if not student_cfg_dict:
            raise ValueError("StudentAdapterConfig.student_config_dict is required")
        # A truthy model_type inside the dict takes precedence over
        # config.student_model_type.
        model_type = student_cfg_dict.get("model_type") or config.student_model_type
        if model_type is None:
            raise ValueError("Missing student model_type")
        # model_type is passed positionally to AutoConfig.for_model, so it must
        # not also be supplied as a keyword argument.
        cfg_kwargs = dict(student_cfg_dict)
        cfg_kwargs.pop("model_type", None)
        student_cfg = AutoConfig.for_model(model_type, **cfg_kwargs)
        # from_config builds the architecture only; no pretrained weights are
        # loaded at this point.
        self.student = AutoModelForCausalLM.from_config(student_cfg, trust_remote_code=True)
        # Prefer the instantiated student's hidden_size; fall back to the config value.
        s_dim = int(getattr(self.student.config, "hidden_size", config.student_hidden_size))
        t_dim = int(config.teacher_hidden_size)
        self.adapter = Adapter(
            s_dim=s_dim,
            t_dim=t_dim,
            dim=config.adapter_dim,
            heads=config.adapter_heads,
            blocks=config.adapter_blocks,
            ff_mult=config.adapter_ff_mult,
            dropout=config.adapter_dropout,
        )
        self.hs_tap_index = int(config.hs_tap_index)
        self.post_init()

    def _extract_hs(self, outputs, idx: int):
        """Return ``outputs.hidden_states[idx]`` with explicit errors.

        Raises:
            RuntimeError: If ``outputs.hidden_states`` is None.
            IndexError: If ``idx`` is outside ``[-len, len)``.
        """
        hs = outputs.hidden_states
        if hs is None:
            raise RuntimeError("Student output_hidden_states is required")
        if not (-len(hs) <= idx < len(hs)):
            raise IndexError(f"hidden-state index {idx} out of range for len={len(hs)}")
        return hs[idx]

    def forward(self, input_ids=None, attention_mask=None, output_hidden_states=True, return_dict=True, **kwargs):
        """Encode ``input_ids`` with the student and return adapted hidden states.

        Args:
            input_ids: Token ids; required.
            attention_mask: Optional mask (1/True = real token). Defaults to all
                ones; bool masks are converted to ``torch.long`` first.
            output_hidden_states: Accepted for API compatibility but ignored.
                Hidden states are always requested from the student and always
                returned.
            return_dict: When falsy, return the tuple
                ``(last_hidden_state, None, hidden_states, None)``.
            **kwargs: Forwarded unchanged to the student model.

        Returns:
            ``BaseModelOutputWithPast`` (or the tuple above). ``last_hidden_state``
            is the adapter output. ``hidden_states`` is the student's tuple with
            entry ``-2`` replaced by the adapter output, or with it appended if
            fewer than two hidden states exist. ``past_key_values`` and
            ``attentions`` are always None.

        Raises:
            ValueError: If ``input_ids`` is None.
            RuntimeError: If the student returns no hidden states.
            IndexError: If ``hs_tap_index`` is out of range.
        """
        if input_ids is None:
            raise ValueError("input_ids is required")
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.long)
        # Qwen3 student model expects long dtype; pipeline may pass bool masks
        if attention_mask.dtype == torch.bool:
            attention_mask = attention_mask.long()
        out = self.student(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
            **kwargs,
        )
        # Extract the tapped layer before materialising the list. If the student
        # ignored output_hidden_states, list(None) would otherwise raise an opaque
        # TypeError, and the explicit RuntimeError in _extract_hs would never fire.
        s_hs = self._extract_hs(out, self.hs_tap_index)
        hs_list = list(out.hidden_states)
        # Feed the adapter in its own parameter dtype (the student may run in
        # a different precision).
        ad_dtype = next(self.adapter.parameters()).dtype
        if s_hs.dtype != ad_dtype:
            s_hs = s_hs.to(ad_dtype)
        adapted = self.adapter(s_hs, attention_mask)
        # NOTE(review): downstream callers presumably read hidden_states[-2], so
        # the adapted output is placed there; confirm against the pipeline.
        if len(hs_list) >= 2:
            hs_list[-2] = adapted
        else:
            hs_list.append(adapted)
        if not return_dict:
            return (adapted, None, tuple(hs_list), None)
        return BaseModelOutputWithPast(
            last_hidden_state=adapted,
            past_key_values=None,
            hidden_states=tuple(hs_list),
            attentions=None,
        )