import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence

from ..core.gradient import gradient_checkpoint_forward
from .z_image_dit import ZImageTransformerBlock
class ZImageControlTransformerBlock(ZImageTransformerBlock):
    def __init__(
        self,
        layer_id: int = 1000,
        dim: int = 3840,
        n_heads: int = 30,
        n_kv_heads: int = 30,
        norm_eps: float = 1e-5,
        qk_norm: bool = True,
        modulation: bool = True,
        block_id: int = 0,
    ):
        super().__init__(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation)
        self.block_id = block_id
        # Only the first block injects the main-branch latents into the control stream.
        if block_id == 0:
            self.before_proj = nn.Linear(self.dim, self.dim)
        self.after_proj = nn.Linear(self.dim, self.dim)

    def forward(self, c, x, **kwargs):
        if self.block_id == 0:
            # First block: fuse the control features with the main-branch
            # latents and start an empty hint stack.
            c = self.before_proj(c) + x
            all_c = []
        else:
            # Later blocks: unpack the accumulated hints; the running control
            # state is carried in the last slot of the stack.
            all_c = list(torch.unbind(c))
            c = all_c.pop(-1)
        c = super().forward(c, **kwargs)
        c_skip = self.after_proj(c)
        # Append this block's hint and the updated control state, then re-stack
        # so the whole bundle can be threaded through gradient checkpointing.
        all_c += [c_skip, c]
        c = torch.stack(all_c)
        return c
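

# --- Illustration (not part of the original module) ---
# A minimal, standalone sketch of the hint-stacking pattern used by
# ZImageControlTransformerBlock.forward above: each block pops the running
# control state off the stack, emits a projected hint, and pushes both back.
# The transformer block itself is replaced by a no-op here, so this only
# demonstrates the bookkeeping, not the real computation.
def _demo_hint_stacking(n_blocks=3, dim=8):
    x = torch.randn(2, 4, dim)  # stand-in for main-branch latents (batch, seq, dim)
    c = torch.randn(2, 4, dim)  # stand-in for control features
    after_projs = [nn.Linear(dim, dim) for _ in range(n_blocks)]  # stand-ins for after_proj
    for block_id, after_proj in enumerate(after_projs):
        if block_id == 0:
            c = c + x  # before_proj omitted for brevity
            all_c = []
        else:
            all_c = list(torch.unbind(c))
            c = all_c.pop(-1)
        # (the real block runs super().forward(c, **kwargs) here)
        all_c += [after_proj(c), c]
        c = torch.stack(all_c)
    hints = torch.unbind(c)[:-1]  # one hint per block, for the main DiT branch
    assert len(hints) == n_blocks
    return hints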
class ZImageControlNet(torch.nn.Module):
    def __init__(
        self,
        control_layers_places=(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28),
        control_in_dim=33,
        dim=3840,
        n_refiner_layers=2,
    ):
        super().__init__()
        # One control block per hooked DiT layer; only the first (block_id == 0)
        # owns a before_proj that injects the main-branch latents.
        self.control_layers = nn.ModuleList(
            [ZImageControlTransformerBlock(layer_id=i, block_id=i) for i in control_layers_places]
        )
        # Patch embedder, keyed by "{patch_size}-{f_patch_size}".
        self.control_all_x_embedder = nn.ModuleDict(
            {"2-1": nn.Linear(1 * 2 * 2 * control_in_dim, dim, bias=True)}
        )
        self.control_noise_refiner = nn.ModuleList(
            [ZImageControlTransformerBlock(block_id=layer_id) for layer_id in range(n_refiner_layers)]
        )
        # Maps a DiT layer index to the index of the hint produced for it
        # (derived from control_layers_places so the two stay in sync).
        self.control_layers_mapping = {i: n for n, i in enumerate(control_layers_places)}
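
    # --- Illustration (not part of the original module) ---
    # A hedged sketch of how the main Z-Image DiT might consume the hints via
    # control_layers_mapping. The actual consumption happens inside the DiT,
    # not in this file, and additive injection is an assumption.
    @staticmethod
    def _demo_apply_hint(x, hints, layer_id, control_layers_mapping):
        if layer_id in control_layers_mapping:
            # Add the hint produced for this DiT layer to the main-branch latents.
            x = x + hints[control_layers_mapping[layer_id]]
        return x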
    def forward_layers(
        self,
        x,
        cap_feats,
        control_context,
        control_context_item_seqlens,
        kwargs,
        use_gradient_checkpointing=False,
        use_gradient_checkpointing_offload=False,
    ):
        bsz = len(control_context)
        # Build the unified sequence: valid control tokens followed by valid
        # caption tokens, per item, then re-pad to a rectangular batch.
        cap_item_seqlens = [len(_) for _ in cap_feats]
        control_context_unified = []
        for i in range(bsz):
            control_context_len = control_context_item_seqlens[i]
            cap_len = cap_item_seqlens[i]
            control_context_unified.append(
                torch.cat([control_context[i][:control_context_len], cap_feats[i][:cap_len]])
            )
        c = pad_sequence(control_context_unified, batch_first=True, padding_value=0.0)
        # Arguments shared by every control block.
        new_kwargs = dict(x=x)
        new_kwargs.update(kwargs)
        for layer in self.control_layers:
            c = gradient_checkpoint_forward(
                layer,
                use_gradient_checkpointing=use_gradient_checkpointing,
                use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
                c=c,
                **new_kwargs,
            )
        # Every slot except the last is a per-layer hint for the main DiT branch.
        hints = torch.unbind(c)[:-1]
        return hints
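
    # --- Illustration (not part of the original module) ---
    # Standalone sketch of the unify-then-pad step in forward_layers above:
    # valid control tokens and valid caption tokens are concatenated per item,
    # and pad_sequence restores a rectangular batch afterwards. Shapes are
    # arbitrary toy values.
    @staticmethod
    def _demo_unify_and_pad():
        control = [torch.randn(5, 4), torch.randn(3, 4)]  # valid lengths 5 and 3
        caps = [torch.randn(2, 4), torch.randn(4, 4)]     # valid lengths 2 and 4
        unified = [torch.cat([c, t]) for c, t in zip(control, caps)]
        padded = pad_sequence(unified, batch_first=True, padding_value=0.0)
        assert padded.shape == (2, 7, 4)  # max(5 + 2, 3 + 4) == 7
        return padded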
    def forward_refiner(
        self,
        dit,
        x,
        cap_feats,
        control_context,
        kwargs,
        t=None,
        patch_size=2,
        f_patch_size=1,
        use_gradient_checkpointing=False,
        use_gradient_checkpointing_offload=False,
    ):
        # Patchify the control inputs with the DiT's own patchifier so the
        # control tokens line up with the main-branch latent tokens.
        bsz = len(control_context)
        device = control_context[0].device
        (
            control_context,
            control_context_size,
            control_context_pos_ids,
            control_context_inner_pad_mask,
        ) = dit.patchify_controlnet(control_context, patch_size, f_patch_size, cap_feats[0].size(0))

        # Embed the control patches and prepare padding, RoPE, and attention masks.
        control_context_item_seqlens = [len(_) for _ in control_context]
        assert all(_ % 2 == 0 for _ in control_context_item_seqlens)
        control_context_max_item_seqlen = max(control_context_item_seqlens)
        control_context = torch.cat(control_context, dim=0)
        control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context)
        # Match the t_embedder output dtype to control_context for layerwise casting compatibility.
        adaln_input = t.type_as(control_context)
        control_context[torch.cat(control_context_inner_pad_mask)] = dit.x_pad_token.to(
            dtype=control_context.dtype, device=control_context.device
        )
        control_context = list(control_context.split(control_context_item_seqlens, dim=0))
        control_context_freqs_cis = list(
            dit.rope_embedder(torch.cat(control_context_pos_ids, dim=0)).split(control_context_item_seqlens, dim=0)
        )
        control_context = pad_sequence(control_context, batch_first=True, padding_value=0.0)
        control_context_freqs_cis = pad_sequence(control_context_freqs_cis, batch_first=True, padding_value=0.0)
        # Boolean mask marking valid (non-pad) positions per item.
        control_context_attn_mask = torch.zeros(
            (bsz, control_context_max_item_seqlen), dtype=torch.bool, device=device
        )
        for i, seq_len in enumerate(control_context_item_seqlens):
            control_context_attn_mask[i, :seq_len] = 1

        c = control_context
        # Arguments shared by every refiner block.
        new_kwargs = dict(
            x=x,
            attn_mask=control_context_attn_mask,
            freqs_cis=control_context_freqs_cis,
            adaln_input=adaln_input,
        )
        new_kwargs.update(kwargs)
        for layer in self.control_noise_refiner:
            c = gradient_checkpoint_forward(
                layer,
                use_gradient_checkpointing=use_gradient_checkpointing,
                use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
                c=c,
                **new_kwargs,
            )
        # Split the stacked result once: hints for the DiT, plus the refined
        # control context that feeds forward_layers.
        hints, control_context = torch.unbind(c)[:-1], torch.unbind(c)[-1]
        return hints, control_context, control_context_item_seqlens
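

# --- Illustration (not part of the original module) ---
# Standalone sketch of the boolean attention mask built in forward_refiner:
# True marks valid (non-pad) positions after pad_sequence, one row per item.
def _demo_attn_mask(seqlens=(6, 4)):
    mask = torch.zeros((len(seqlens), max(seqlens)), dtype=torch.bool)
    for i, seq_len in enumerate(seqlens):
        mask[i, :seq_len] = True
    # e.g. seqlens=(6, 4) -> [[T, T, T, T, T, T],
    #                         [T, T, T, T, F, F]]
    return mask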