import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence

from ..core.gradient import gradient_checkpoint_forward
from .z_image_dit import ZImageTransformerBlock


class ZImageControlTransformerBlock(ZImageTransformerBlock):
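    """ZImageTransformerBlock that also emits a projected ControlNet hint.

    The first block (block_id == 0) projects the incoming control stream with
    before_proj and adds the latent tokens x; every block then appends an
    after_proj skip connection, so unbinding the output yields one hint per
    block plus the running control state as the last entry.
    """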
    def __init__(
        self,
        layer_id: int = 1000,
        dim: int = 3840,
        n_heads: int = 30,
        n_kv_heads: int = 30,
        norm_eps: float = 1e-5,
        qk_norm: bool = True,
        modulation: bool = True,
        block_id: int = 0,
    ):
        super().__init__(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation)
        self.block_id = block_id
        if block_id == 0:
            self.before_proj = nn.Linear(self.dim, self.dim)
        self.after_proj = nn.Linear(self.dim, self.dim)

    def forward(self, c, x, **kwargs):
        if self.block_id == 0:
            # First block: project the control stream and mix in the latent tokens.
            c = self.before_proj(c) + x
            all_c = []
        else:
            # Later blocks receive a stack of [hint_0, ..., hint_{k-1}, current control state].
            all_c = list(torch.unbind(c))
            c = all_c.pop(-1)

        c = super().forward(c, **kwargs)
        c_skip = self.after_proj(c)
        # Append this block's hint and the updated control state, then re-stack.
        all_c += [c_skip, c]
        c = torch.stack(all_c)
        return c


class ZImageControlNet(torch.nn.Module):
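    """ControlNet adapter for the Z-Image DiT.

    forward_refiner patchifies, embeds, and refines the raw control condition;
    forward_layers then runs the refined control tokens (concatenated with the
    caption features) through the control transformer blocks and returns one
    hint per block for the main DiT to consume.
    """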
    def __init__(
        self,
        control_layers_places=(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28),
        control_in_dim=33,
        dim=3840,
        n_refiner_layers=2,
    ):
        super().__init__()
        # One control block per hooked DiT layer; block_id is set to the layer index,
        # so (with the default places) only the block at layer 0 owns a before_proj.
        self.control_layers = nn.ModuleList([ZImageControlTransformerBlock(layer_id=i, block_id=i) for i in control_layers_places])
        # Keyed by "{patch_size}-{f_patch_size}"; input features = f_patch_size * patch_size**2 * control_in_dim.
        self.control_all_x_embedder = nn.ModuleDict({"2-1": nn.Linear(1 * 2 * 2 * control_in_dim, dim, bias=True)})
        self.control_noise_refiner = nn.ModuleList([ZImageControlTransformerBlock(block_id=layer_id) for layer_id in range(n_refiner_layers)])
        # Map each hooked DiT layer index to the index of its control block.
        self.control_layers_mapping = {layer: idx for idx, layer in enumerate(control_layers_places)}

    def forward_layers(
        self,
        x,
        cap_feats,
        control_context,
        control_context_item_seqlens,
        kwargs,
        use_gradient_checkpointing=False,
        use_gradient_checkpointing_offload=False,
    ):
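        """Run the unified control + caption tokens through the control blocks.

        Returns one hint tensor per control layer; the final stacked entry
        (the running control state) is dropped.
        """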
        bsz = len(control_context)
        # Build one unified sequence per item: its unpadded control tokens followed
        # by its caption tokens, then pad across the batch.
        cap_item_seqlens = [len(_) for _ in cap_feats]
        control_context_unified = []
        for i in range(bsz):
            control_context_len = control_context_item_seqlens[i]
            cap_len = cap_item_seqlens[i]
            control_context_unified.append(torch.cat([control_context[i][:control_context_len], cap_feats[i][:cap_len]]))
        c = pad_sequence(control_context_unified, batch_first=True, padding_value=0.0)

        # arguments
        new_kwargs = dict(x=x)
        new_kwargs.update(kwargs)
        
        for layer in self.control_layers:
            c = gradient_checkpoint_forward(
                layer,
                use_gradient_checkpointing=use_gradient_checkpointing,
                use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
                c=c, **new_kwargs
            )
 
        # The last stacked entry is the running control state; the rest are the per-layer hints.
        hints = torch.unbind(c)[:-1]
        return hints
    
    def forward_refiner(
        self,
        dit,
        x,
        cap_feats,
        control_context,
        kwargs,
        t=None,
        patch_size=2,
        f_patch_size=1,
        use_gradient_checkpointing=False,
        use_gradient_checkpointing_offload=False,
    ):
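        """Patchify, embed, and refine the raw control condition.

        Returns the refiner hints, the refined (padded) control tokens, and the
        per-item sequence lengths that forward_layers consumes afterwards.
        """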
        # embeddings
        bsz = len(control_context)
        device = control_context[0].device
        (
            control_context,
            control_context_size,
            control_context_pos_ids,
            control_context_inner_pad_mask,
        ) = dit.patchify_controlnet(control_context, patch_size, f_patch_size, cap_feats[0].size(0))

        # control_context embed & refine
        control_context_item_seqlens = [len(_) for _ in control_context]
        assert all(_ % 2 == 0 for _ in control_context_item_seqlens)
        control_context_max_item_seqlen = max(control_context_item_seqlens)

        control_context = torch.cat(control_context, dim=0)
        control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context)

        # Cast the timestep embedding to control_context's dtype for layerwise-casting compatibility.
        adaln_input = t.type_as(control_context)
        control_context[torch.cat(control_context_inner_pad_mask)] = dit.x_pad_token.to(dtype=control_context.dtype, device=control_context.device)
        control_context = list(control_context.split(control_context_item_seqlens, dim=0))
        control_context_freqs_cis = list(dit.rope_embedder(torch.cat(control_context_pos_ids, dim=0)).split(control_context_item_seqlens, dim=0))

        control_context = pad_sequence(control_context, batch_first=True, padding_value=0.0)
        control_context_freqs_cis = pad_sequence(control_context_freqs_cis, batch_first=True, padding_value=0.0)
        control_context_attn_mask = torch.zeros((bsz, control_context_max_item_seqlen), dtype=torch.bool, device=device)
        for i, seq_len in enumerate(control_context_item_seqlens):
            control_context_attn_mask[i, :seq_len] = 1
        c = control_context

        # arguments
        new_kwargs = dict(
            x=x, 
            attn_mask=control_context_attn_mask,
            freqs_cis=control_context_freqs_cis, 
            adaln_input=adaln_input,
        )
        new_kwargs.update(kwargs)
        
        for layer in self.control_noise_refiner:
            c = gradient_checkpoint_forward(
                layer,
                use_gradient_checkpointing=use_gradient_checkpointing,
                use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
                c=c, **new_kwargs
            )
 
        # The last stacked entry is the refined control stream; the rest are the refiner hints.
        unbound = torch.unbind(c)
        hints, control_context = unbound[:-1], unbound[-1]

        return hints, control_context, control_context_item_seqlens
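
# Sketch of the assumed call flow (hypothetical names such as `controlnet` and `t_emb`;
# the real call sites live in the Z-Image DiT, which is not part of this file):
#   refiner_hints, control_context, seqlens = controlnet.forward_refiner(
#       dit, x, cap_feats, control_context, kwargs, t=t_emb)
#   layer_hints = controlnet.forward_layers(
#       x, cap_feats, control_context, seqlens, kwargs)
# control_layers_mapping then tells the DiT which hint pairs with which of its layers.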