Daankular commited on
Commit
c74f8a5
·
1 Parent(s): 6b03536

Patch PSHuman attn: use kwargs in Attention.__init__ (diffusers added kv_heads param, breaking positional args)

Browse files
patches/pshuman/mvdiffusion/models_unclip/attn_processors.py ADDED
@@ -0,0 +1,631 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, Optional
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+
7
+
8
+ from diffusers.models.attention import Attention
9
+ from diffusers.utils.import_utils import is_xformers_available
10
+ from einops import rearrange, repeat
11
+ import math
12
+
13
+ import torch.nn.functional as F
14
+ if is_xformers_available():
15
+ import xformers
16
+ import xformers.ops
17
+ else:
18
+ xformers = None
19
+
20
class RowwiseMVAttention(Attention):
    """Attention layer that is hard-wired to the row-wise multi-view
    xformers processor, no matter which backend the caller requests."""

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
    ):
        # The boolean flag and extra arguments are deliberately ignored:
        # this layer only ever runs with the multi-view processor.
        self.set_processor(XFormersMVAttnProcessor())
27
+
28
class IPCDAttention(Attention):
    """Attention layer hard-wired to the IP/cross-domain xformers processor."""

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
    ):
        # Flag and extra arguments are intentionally unused: the IP
        # cross-domain processor is the only supported backend here.
        self.set_processor(XFormersIPCDAttnProcessor())
35
+
36
+
37
+
38
class XFormersMVAttnProcessor:
    r"""
    Row-wise multi-view attention processor backed by xformers.

    Tokens from ``num_views`` views of the same scene are regrouped so that
    attention runs across all views of each image *row* at once:
    ``(b v) (h w) c -> (b h) (v w) c``.  With ``cd_attention_mid=True`` the
    batch is additionally assumed to hold two domains stacked along dim 0
    (e.g. normal/color), which are folded side-by-side along the width before
    attention and split back afterwards.

    NOTE(review): the row-wise rearranges assume a square token grid
    (``sequence_length == height * height``) — confirm against callers.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_views=1,          # number of views packed into the batch dim
        multiview_attention=True,  # accepted for interface parity; not read below
        cd_attention_mid=False     # True: batch holds 2 domains to be joint-attended
    ):
        # print(num_views)
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        # Flatten a (B, C, H, W) feature map into a (B, H*W, C) token sequence.
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        # Side length of the (assumed square) token grid; reused by the
        # rearranges and the ndim==4 un-flatten below.
        height = int(math.sqrt(sequence_length))
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        # from yuancheng; here attention_mask is None
        if attention_mask is not None:
            # expand our mask's singleton query_tokens dimension:
            # [batch*heads, 1, key_tokens] ->
            # [batch*heads, query_tokens, key_tokens]
            # so that it can be added as a bias onto the attention scores that xformers computes:
            # [batch*heads, query_tokens, key_tokens]
            # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
            _, query_tokens, _ = hidden_states.shape
            attention_mask = attention_mask.expand(-1, query_tokens, -1)

        if attn.group_norm is not None:
            # Group norm was designed for per-image tokens; after the row-wise
            # regrouping below its statistics change meaning, hence the warning.
            print('Warning: using group norm, pay attention to use it in row-wise attention')
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key_raw = attn.to_k(encoder_hidden_states)
        value_raw = attn.to_v(encoder_hidden_states)

        # print('query', query.shape, 'key', key.shape, 'value', value.shape)
        # pdb.set_trace()
        def transpose(tensor):
            # Fold the two stacked domains (first/second half of the batch)
            # side-by-side along the width, then regroup row-wise:
            # (2b*v, h*w, c) -> (b*h, v*2w, c).
            tensor = rearrange(tensor, "(b v) (h w) c -> b v h w c", v=num_views, h=height)
            tensor_0, tensor_1 = torch.chunk(tensor, dim=0, chunks=2) # b v h w c
            tensor = torch.cat([tensor_0, tensor_1], dim=3) # b v h 2w c
            tensor = rearrange(tensor, "b v h w c -> (b h) (v w) c", v=num_views, h=height)
            return tensor
        # print(mvcd_attention)
        # import pdb;pdb.set_trace()
        if cd_attention_mid:
            key = transpose(key_raw)
            value = transpose(value_raw)
            query = transpose(query)
        else:
            # Plain row-wise regrouping: each attention sequence is one image
            # row concatenated across all views.
            key = rearrange(key_raw, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
            value = rearrange(value_raw, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
            query = rearrange(query, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height) # torch.Size([192, 384, 320])


        query = attn.head_to_batch_dim(query) # torch.Size([960, 384, 64])
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        # Invert the regrouping done before attention so the output layout
        # matches the input layout again.
        if cd_attention_mid:
            hidden_states = rearrange(hidden_states, "(b h) (v w) c -> b v h w c", v=num_views, h=height)
            hidden_states_0, hidden_states_1 = torch.chunk(hidden_states, dim=3, chunks=2) # b v h w c
            hidden_states = torch.cat([hidden_states_0, hidden_states_1], dim=0) # 2b v h w c
            hidden_states = rearrange(hidden_states, "b v h w c -> (b v) (h w) c", v=num_views, h=height)
        else:
            hidden_states = rearrange(hidden_states, "(b h) (v w) c -> (b v) (h w) c", v=num_views, h=height)
        if input_ndim == 4:
            # NOTE(review): `height` was overwritten with sqrt(sequence_length)
            # above, so this un-flatten is only correct for square inputs.
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
144
+
145
+
146
class XFormersIPCDAttnProcessor:
    r"""
    Cross-domain attention processor with face-patch injection (xformers).

    ``process`` joint-attends two domains (the batch is assumed to hold the
    normal-domain and color-domain tokens stacked along dim 0), then blends a
    downscaled copy of the last view's features (presumably a face crop —
    TODO confirm against the caller) into a fixed region of the first view.
    """

    def process(self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_tasks=2,   # number of stacked domains; only 2 is supported
        num_views=6):  # body views per sample; batch view dim is num_views+1 (extra face view)
        ### TODO: num_views
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        # Flatten a (B, C, H, W) feature map into a (B, H*W, C) token sequence.
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        # Side length of the (assumed square) token grid, and the sub-window
        # [0:height_st, height_st:height_end] that receives the blended face
        # features.  NOTE(review): the window is square only when height is
        # divisible by 3 — confirm feature resolutions used by callers.
        height = int(math.sqrt(sequence_length))
        height_st = height // 3
        height_end = height - height_st

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        # from yuancheng; here attention_mask is None
        if attention_mask is not None:
            # expand our mask's singleton query_tokens dimension:
            # [batch*heads, 1, key_tokens] ->
            # [batch*heads, query_tokens, key_tokens]
            # so that it can be added as a bias onto the attention scores that xformers computes:
            # [batch*heads, query_tokens, key_tokens]
            # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
            _, query_tokens, _ = hidden_states.shape
            attention_mask = attention_mask.expand(-1, query_tokens, -1)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        assert num_tasks == 2  # only support two tasks now

        # NOTE: an earlier experimental IP-adapter / face-cross-attention
        # variant was kept here as commented-out code; removed for clarity
        # (see repository history for the original sketch).

        def transpose(tensor):
            # Concatenate the two stacked domains along the token dim so each
            # sequence attends over both domains: (2bv, hw, c) -> (bv, 2hw, c).
            tensor_0, tensor_1 = torch.chunk(tensor, dim=0, chunks=2) # bv hw c
            tensor = torch.cat([tensor_0, tensor_1], dim=1) # bv 2hw c
            return tensor
        key = transpose(key)
        value = transpose(value)
        query = transpose(query)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        # Undo transpose(): split the token dim back into the two domains.
        hidden_states_normal, hidden_states_color = torch.chunk(hidden_states, dim=1, chunks=2) # bv, hw, c

        # Normal domain: downscale the last view's features and blend them
        # 50/50 into the face window of the first (front) view.  `.detach()`
        # stops gradients flowing through the injected copy; `.clone()` avoids
        # in-place mutation of a tensor needed by autograd.
        hidden_states_normal = rearrange(hidden_states_normal, "(b v) (h w) c -> b v h w c", v=num_views+1, h=height)
        face_normal = rearrange(hidden_states_normal[:, -1, :, :, :], 'b h w c -> b c h w').detach()
        face_normal = rearrange(F.interpolate(face_normal, size=(height_st, height_st), mode='bilinear'), 'b c h w -> b h w c')
        hidden_states_normal = hidden_states_normal.clone()  # Create a copy of hidden_states_normal
        hidden_states_normal[:, 0, :height_st, height_st:height_end, :] = 0.5 * hidden_states_normal[:, 0, :height_st, height_st:height_end, :] + 0.5 * face_normal
        # hidden_states_normal[:, 0, :height_st, height_st:height_end, :] = 0.1 * hidden_states_normal[:, 0, :height_st, height_st:height_end, :] + 0.9 * face_normal
        hidden_states_normal = rearrange(hidden_states_normal, "b v h w c -> (b v) (h w) c")

        # Color domain: same face-window blending as above.
        hidden_states_color = rearrange(hidden_states_color, "(b v) (h w) c -> b v h w c", v=num_views+1, h=height)
        face_color = rearrange(hidden_states_color[:, -1, :, :, :], 'b h w c -> b c h w').detach()
        face_color = rearrange(F.interpolate(face_color, size=(height_st, height_st), mode='bilinear'), 'b c h w -> b h w c')
        hidden_states_color = hidden_states_color.clone()  # Create a copy of hidden_states_color
        hidden_states_color[:, 0, :height_st, height_st:height_end, :] = 0.5 * hidden_states_color[:, 0, :height_st, height_st:height_end, :] + 0.5 * face_color
        # hidden_states_color[:, 0, :height_st, height_st:height_end, :] = 0.1 * hidden_states_color[:, 0, :height_st, height_st:height_end, :] + 0.9 * face_color
        hidden_states_color = rearrange(hidden_states_color, "b v h w c -> (b v) (h w) c")

        # Re-stack the two domains along the batch dim: (2bv, hw, c).
        hidden_states = torch.cat([hidden_states_normal, hidden_states_color], dim=0) # 2bv hw c

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_tasks=2,
    ):
        # Thin wrapper around process(); num_views keeps its default (6).
        hidden_states = self.process(attn, hidden_states, encoder_hidden_states, attention_mask, temb, num_tasks)
        return hidden_states
304
+
305
class IPCrossAttn(Attention):
    r"""
    Cross-attention layer for IP-Adapter-style conditioning.

    Args:
        query_dim (`int`): channel dimension of the query features.
        cross_attention_dim (`int`): channel dimension of `encoder_hidden_states`.
        heads (`int`): number of attention heads.
        dim_head (`int`): dimension per head.
        dropout (`float`): dropout probability in the output projection.
        bias (`bool`): whether q/k/v projections use a bias.
        upcast_attention (`bool`): compute attention in float32.
        ip_scale (`float`, defaults to 1.0): weight of the image-prompt branch.
    """

    def __init__(self,
        query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, ip_scale=1.0):
        # Forward everything to Attention by keyword: newer diffusers inserted
        # a `kv_heads` positional parameter, so positional forwarding breaks.
        super().__init__(
            query_dim=query_dim,
            cross_attention_dim=cross_attention_dim,
            heads=heads,
            dim_head=dim_head,
            dropout=dropout,
            bias=bias,
            upcast_attention=upcast_attention,
        )
        self.ip_scale = ip_scale

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
    ):
        # Always use the xformers IP cross-attention processor; the flag and
        # extra arguments are intentionally ignored.
        self.set_processor(XFormersIPCrossAttnProcessor())
340
+
341
class XFormersIPCrossAttnProcessor:
    r"""
    Plain cross-attention processor backed by xformers, used by
    :class:`IPCrossAttn`.

    The IP-adapter / region-control branches that earlier revisions sketched
    here as commented-out code were removed for clarity (see repository
    history); `num_views` is accepted only for interface parity with the
    multi-view processors and is not read.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_views=1
    ):
        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        # Flatten a (B, C, H, W) feature map into a (B, H*W, C) token sequence.
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        # Fix: every sibling processor in this file falls back to
        # self-attention when no encoder states are given and applies
        # attn.norm_cross; this one previously crashed on
        # encoder_hidden_states=None. Behavior with a provided
        # encoder_hidden_states (and norm_cross=False) is unchanged.
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        # TODO: region control (masked blending of IP hidden states) —
        # intentionally not implemented; see repository history for the sketch.
        return hidden_states
436
+
437
+
438
class RowwiseMVProcessor:
    r"""
    Row-wise multi-view attention processor (vanilla attention fallback).

    Same token regrouping as :class:`XFormersMVAttnProcessor` —
    ``(b v) (h w) c -> (b h) (v w) c`` so attention spans one image row across
    all ``num_views`` views — but computed with
    ``attn.get_attention_scores`` + ``torch.bmm`` instead of xformers.

    NOTE(review): assumes a square token grid
    (``sequence_length == height * height``) — confirm against callers.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_views=1,          # number of views packed into the batch dim
        cd_attention_mid=False  # True: batch holds 2 domains to be joint-attended
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        # Flatten a (B, C, H, W) feature map into a (B, H*W, C) token sequence.
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        # Side length of the (assumed square) token grid.
        height = int(math.sqrt(sequence_length))
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # print('query', query.shape, 'key', key.shape, 'value', value.shape)
        #([bx4, 1024, 320]) key torch.Size([bx4, 1024, 320]) value torch.Size([bx4, 1024, 320])
        # pdb.set_trace()
        # multi-view self-attention
        def transpose(tensor):
            # Fold the two stacked domains (first/second half of the batch)
            # side-by-side along the width, then regroup row-wise:
            # (2b*v, h*w, c) -> (b*h, v*2w, c).
            tensor = rearrange(tensor, "(b v) (h w) c -> b v h w c", v=num_views, h=height)
            tensor_0, tensor_1 = torch.chunk(tensor, dim=0, chunks=2) # b v h w c
            tensor = torch.cat([tensor_0, tensor_1], dim=3) # b v h 2w c
            tensor = rearrange(tensor, "b v h w c -> (b h) (v w) c", v=num_views, h=height)
            return tensor

        if cd_attention_mid:
            key = transpose(key)
            value = transpose(value)
            query = transpose(query)
        else:
            # Plain row-wise regrouping: each attention sequence is one image
            # row concatenated across all views.
            key = rearrange(key, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
            value = rearrange(value, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
            query = rearrange(query, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height) # torch.Size([192, 384, 320])

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        # Vanilla attention: explicit score matrix + batched matmul.
        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        # Invert the regrouping done before attention.
        if cd_attention_mid:
            hidden_states = rearrange(hidden_states, "(b h) (v w) c -> b v h w c", v=num_views, h=height)
            hidden_states_0, hidden_states_1 = torch.chunk(hidden_states, dim=3, chunks=2) # b v h w c
            hidden_states = torch.cat([hidden_states_0, hidden_states_1], dim=0) # 2b v h w c
            hidden_states = rearrange(hidden_states, "b v h w c -> (b v) (h w) c", v=num_views, h=height)
        else:
            hidden_states = rearrange(hidden_states, "(b h) (v w) c -> (b v) (h w) c", v=num_views, h=height)
        if input_ndim == 4:
            # NOTE(review): `height` was overwritten with sqrt(sequence_length)
            # above, so this un-flatten is only correct for square inputs.
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
531
+
532
+
533
class CDAttention(Attention):
    """Attention layer hard-wired to the cross-domain xformers processor."""

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
    ):
        # Flag and extra arguments are intentionally unused: the cross-domain
        # processor is the only backend this layer supports.
        self.set_processor(XFormersCDAttnProcessor())
552
+
553
class XFormersCDAttnProcessor:
    r"""
    Cross-domain attention processor backed by xformers.

    The batch is assumed to hold two domains (e.g. normal/color) stacked
    along dim 0.  Their token sequences are concatenated so each sequence
    attends over both domains, then split back after the output projection.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_tasks=2  # number of stacked domains; only 2 is supported
    ):

        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        # Flatten a (B, C, H, W) feature map into a (B, H*W, C) token sequence.
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)


        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        assert num_tasks == 2  # only support two tasks now

        def transpose(tensor):
            # Concatenate the two stacked domains along the token dim:
            # (2bv, hw, c) -> (bv, 2hw, c).
            tensor_0, tensor_1 = torch.chunk(tensor, dim=0, chunks=2) # bv hw c
            tensor = torch.cat([tensor_0, tensor_1], dim=1) # bv 2hw c
            return tensor
        key = transpose(key)
        value = transpose(value)
        query = transpose(query)


        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        # FIX: invert transpose() by chunking the *token* dim.  The previous
        # `torch.cat([hidden_states[:, 0], hidden_states[:, 1]], dim=0)`
        # indexed single tokens, yielding (2bv, c) instead of the (2bv, hw, c)
        # the inline comment promised, which is shape-incompatible with the
        # residual below.  This mirrors XFormersIPCDAttnProcessor's undo step.
        hidden_states_0, hidden_states_1 = torch.chunk(hidden_states, dim=1, chunks=2)  # bv hw c
        hidden_states = torch.cat([hidden_states_0, hidden_states_1], dim=0)  # 2bv hw c
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
631
+