Instructions to use SearchingMan/Z-Image-Turbo-student-adapter with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use SearchingMan/Z-Image-Turbo-student-adapter with Diffusers:
pip install -U diffusers transformers accelerate
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("SearchingMan/Z-Image-Turbo-student-adapter", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
- Notebooks
- Google Colab
- Kaggle
- Local Apps
- Draw Things
- DiffusionBee
File size: 4,951 Bytes
0179f45
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPast
from .configuration_student_adapter import StudentAdapterConfig
class XAttnBlock(nn.Module):
    """Pre-norm cross-attention block followed by a feed-forward sub-layer.

    Both sub-layers are residual: the attention output is added to ``q``,
    then the feed-forward output is added to that sum.

    Args:
        dim: Feature width of queries, keys/values and the block output.
        heads: Number of attention heads passed to ``nn.MultiheadAttention``.
        ff_mult: Expansion factor of the hidden feed-forward layer.
        dropout: Dropout probability used inside the attention module and
            after each feed-forward linear layer.
    """

    def __init__(self, dim, heads, ff_mult=4, dropout=0.1):
        super().__init__()
        self.norm_q = nn.LayerNorm(dim)
        self.norm_kv = nn.LayerNorm(dim)
        # batch_first=True: inputs are laid out as (batch, seq, dim).
        self.attn = nn.MultiheadAttention(dim, heads, dropout=dropout, batch_first=True)
        self.norm_ff = nn.LayerNorm(dim)
        self.ff = nn.Sequential(
            nn.Linear(dim, dim * ff_mult),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * ff_mult, dim),
            nn.Dropout(dropout),
        )

    def forward(self, q, kv, key_padding_mask=None):
        """Attend from ``q`` to ``kv`` and apply the feed-forward sub-layer.

        Args:
            q: Query tensor, ``(batch, q_len, dim)``.
            kv: Tensor used for both keys and values, ``(batch, kv_len, dim)``.
            key_padding_mask: Optional mask forwarded unchanged to
                ``nn.MultiheadAttention``; callers in this file pass a boolean
                mask in which ``True`` marks key positions to ignore.

        Returns:
            Tensor with the same shape as ``q``.
        """
        # Keys and values are the same normalised tensor, so normalise once
        # instead of running the LayerNorm twice on identical input.
        kv_normed = self.norm_kv(kv)
        attn_out = self.attn(
            self.norm_q(q),
            kv_normed,
            kv_normed,
            key_padding_mask=key_padding_mask,
            need_weights=False,  # only the attended values are needed
        )[0]
        q = q + attn_out
        return q + self.ff(self.norm_ff(q))
class Adapter(nn.Module):
    """Map student hidden states to the teacher width via cross-attention.

    Queries and keys/values are two separate linear projections of the same
    student hidden states. The queries are refined by a stack of
    ``XAttnBlock`` layers and finally projected to ``t_dim``.

    Args:
        s_dim: Feature width of the incoming student hidden states.
        t_dim: Feature width of the produced output.
        dim: Internal width used by the attention blocks.
        heads: Attention heads per block.
        blocks: Number of stacked ``XAttnBlock`` layers.
        ff_mult: Feed-forward expansion factor inside each block.
        dropout: Dropout probability inside each block.
    """

    def __init__(self, s_dim, t_dim, dim=1024, heads=8, blocks=2, ff_mult=4, dropout=0.1):
        super().__init__()
        self.q_proj = nn.Linear(s_dim, dim)
        self.kv_proj = nn.Linear(s_dim, dim)
        layers = []
        for _ in range(blocks):
            layers.append(XAttnBlock(dim, heads, ff_mult=ff_mult, dropout=dropout))
        self.blocks = nn.ModuleList(layers)
        self.proj_out = nn.Linear(dim, t_dim)

    def forward(self, student_hs, mask):
        """Run the adapter over one batch.

        Args:
            student_hs: Student hidden states; the last axis has size ``s_dim``.
            mask: Attention mask aligned with the leading axes of
                ``student_hs``; truthy entries are real tokens, falsy entries
                are padding.

        Returns:
            Tensor whose last axis has size ``t_dim``, with every padded
            position set to zero.
        """
        # nn.MultiheadAttention expects True where a key must be ignored,
        # i.e. the inverse of the incoming mask.
        is_pad = ~mask.bool()
        queries = self.q_proj(student_hs)
        keys_values = self.kv_proj(student_hs)
        for layer in self.blocks:
            queries = layer(queries, keys_values, key_padding_mask=is_pad)
        projected = self.proj_out(queries)
        # Add a trailing axis so the padding mask broadcasts over features.
        return projected.masked_fill(is_pad[..., None], 0)
class StudentAdapterTextEncoder(PreTrainedModel):
    """Text encoder that wraps a student causal LM and an ``Adapter``.

    The student model is built from the configuration stored in
    ``config.student_config_dict``. One of its hidden states (selected by
    ``config.hs_tap_index``) is passed through the adapter, which projects it
    to ``config.teacher_hidden_size``.
    """

    config_class = StudentAdapterConfig
    base_model_prefix = "student"

    def __init__(self, config: StudentAdapterConfig):
        """Build the student model and the adapter from ``config``.

        Raises:
            ValueError: If ``config.student_config_dict`` is empty/missing or
                no student ``model_type`` can be determined.
        """
        super().__init__(config)

        # Work on a copy so the dict stored on the config is never mutated.
        student_cfg_dict = dict(config.student_config_dict or {})
        if not student_cfg_dict:
            raise ValueError("StudentAdapterConfig.student_config_dict is required")

        # model_type is passed positionally to AutoConfig.for_model, so it is
        # removed from the keyword arguments.
        model_type = student_cfg_dict.pop("model_type", None) or config.student_model_type
        if model_type is None:
            raise ValueError("Missing student model_type")

        student_cfg = AutoConfig.for_model(model_type, **student_cfg_dict)
        # NOTE(review): trust_remote_code=True allows execution of model code
        # shipped with a repository; confirm this is still required here.
        self.student = AutoModelForCausalLM.from_config(student_cfg, trust_remote_code=True)

        # Prefer the width reported by the instantiated student; only consult
        # config.student_hidden_size when the student does not provide one, so
        # a config without that field still works in the common case.
        s_dim = getattr(self.student.config, "hidden_size", None)
        if s_dim is None:
            s_dim = config.student_hidden_size
        s_dim = int(s_dim)
        t_dim = int(config.teacher_hidden_size)

        self.adapter = Adapter(
            s_dim=s_dim,
            t_dim=t_dim,
            dim=config.adapter_dim,
            heads=config.adapter_heads,
            blocks=config.adapter_blocks,
            ff_mult=config.adapter_ff_mult,
            dropout=config.adapter_dropout,
        )
        self.hs_tap_index = int(config.hs_tap_index)
        self.post_init()

    def _extract_hs(self, outputs, idx: int):
        """Return ``outputs.hidden_states[idx]`` after validating it.

        Raises:
            RuntimeError: If ``outputs.hidden_states`` is ``None``.
            IndexError: If ``idx`` is outside the valid (possibly negative)
                index range of the hidden-state tuple.
        """
        hs = outputs.hidden_states
        if hs is None:
            raise RuntimeError("Student output_hidden_states is required")
        if not (-len(hs) <= idx < len(hs)):
            raise IndexError(f"hidden-state index {idx} out of range for len={len(hs)}")
        return hs[idx]

    def forward(self, input_ids=None, attention_mask=None, output_hidden_states=True, return_dict=True, **kwargs):
        """Encode ``input_ids`` with the student and adapt the tapped state.

        Args:
            input_ids: Token ids; required.
            attention_mask: Optional mask; defaults to all ones. Boolean masks
                are converted to ``long`` before reaching the student.
            output_hidden_states: Accepted for interface compatibility but
                ignored; the student is always run with hidden states enabled
                because the adapter needs them.
            return_dict: When false, return a plain tuple instead of
                ``BaseModelOutputWithPast``.
            **kwargs: Forwarded unchanged to the student model.

        Returns:
            ``BaseModelOutputWithPast`` (or the tuple
            ``(adapted, None, hidden_states, None)``) where
            ``last_hidden_state`` is the adapter output and ``hidden_states``
            is the student's hidden-state tuple with the adapter output
            written into position ``-2`` (or appended when the student
            returned fewer than two hidden states).

        Raises:
            ValueError: If ``input_ids`` is ``None``.
            RuntimeError: If the student returns no hidden states.
            IndexError: If ``hs_tap_index`` is out of range.
        """
        if input_ids is None:
            raise ValueError("input_ids is required")
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.long)
        # Qwen3 student model expects long dtype; pipeline may pass bool masks
        if attention_mask.dtype == torch.bool:
            attention_mask = attention_mask.long()

        out = self.student(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
            **kwargs,
        )

        # Validate and tap first: materialising the tuple beforehand would
        # raise a bare TypeError on a missing hidden_states and make the
        # descriptive RuntimeError in _extract_hs unreachable.
        s_hs = self._extract_hs(out, self.hs_tap_index)
        hs_list = list(out.hidden_states)

        # Run the adapter in its own parameter dtype.
        ad_dtype = next(self.adapter.parameters()).dtype
        if s_hs.dtype != ad_dtype:
            s_hs = s_hs.to(ad_dtype)
        adapted = self.adapter(s_hs, attention_mask)

        # Expose the adapted states at hidden_states[-2]; presumably the
        # consuming pipeline reads that position — TODO confirm.
        if len(hs_list) >= 2:
            hs_list[-2] = adapted
        else:
            hs_list.append(adapted)
        hidden_states = tuple(hs_list)

        if not return_dict:
            return (adapted, None, hidden_states, None)
        return BaseModelOutputWithPast(
            last_hidden_state=adapted,
            past_key_values=None,
            hidden_states=hidden_states,
            attentions=None,
        )
|