firstkillday committed
Commit 306a77b · verified · 1 Parent(s): 9afeae9

Upload src/pipelines/pipeline_echomimicv2.py with huggingface_hub

Files changed (1):
  1. src/pipelines/pipeline_echomimicv2.py (+625 -625)
src/pipelines/pipeline_echomimicv2.py CHANGED
@@ -1,625 +1,625 @@
Previous revision (removed): line-for-line identical to the uploaded revision shown below, except that `enable_sequential_cpu_offload` selected the offload device without a CPU fallback:

- device = torch.device(f"cuda:{gpu_id}")
 
1
+ import inspect
2
+ import math
3
+ from dataclasses import dataclass
4
+ from typing import Callable, List, Optional, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ from diffusers import DiffusionPipeline
9
+ import torch.nn.functional as F
10
+ from diffusers.image_processor import VaeImageProcessor
11
+ from diffusers.schedulers import (
12
+ DDIMScheduler,
13
+ DPMSolverMultistepScheduler,
14
+ EulerAncestralDiscreteScheduler,
15
+ EulerDiscreteScheduler,
16
+ LMSDiscreteScheduler,
17
+ PNDMScheduler,
18
+ )
19
+ from diffusers.utils import BaseOutput, is_accelerate_available
20
+ from diffusers.utils.torch_utils import randn_tensor
21
+ from einops import rearrange
22
+ from tqdm import tqdm
23
+
24
+ from src.models.mutual_self_attention import ReferenceAttentionControl
25
+ from src.pipelines.context import get_context_scheduler
26
+ from src.pipelines.utils import get_tensor_interpolation_method
27
+
28
+
29
+ @dataclass
30
+ class EchoMimicV2PipelineOutput(BaseOutput):
31
+ videos: Union[torch.Tensor, np.ndarray]
32
+
33
+
34
+ class EchoMimicV2Pipeline(DiffusionPipeline):
35
+
36
+ def __init__(
37
+ self,
38
+ vae,
39
+ reference_unet,
40
+ denoising_unet,
41
+ audio_guider,
42
+ pose_encoder,
43
+ scheduler: Union[
44
+ DDIMScheduler,
45
+ PNDMScheduler,
46
+ LMSDiscreteScheduler,
47
+ EulerDiscreteScheduler,
48
+ EulerAncestralDiscreteScheduler,
49
+ DPMSolverMultistepScheduler,
50
+ ],
51
+ image_proj_model=None,
52
+ tokenizer=None,
53
+ text_encoder=None,
54
+ ):
55
+ super().__init__()
56
+
57
+ self.register_modules(
58
+ vae=vae,
59
+ reference_unet=reference_unet,
60
+ denoising_unet=denoising_unet,
61
+ audio_guider=audio_guider,
62
+ pose_encoder=pose_encoder,
63
+ scheduler=scheduler,
64
+ image_proj_model=image_proj_model,
65
+ tokenizer=tokenizer,
66
+ text_encoder=text_encoder,
67
+ # audio_feature_mapper=audio_feature_mapper
68
+ )
69
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
70
+ self.ref_image_processor = VaeImageProcessor(
71
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True
72
+ )
73
+
74
+ def enable_vae_slicing(self):
75
+ self.vae.enable_slicing()
76
+
77
+ def disable_vae_slicing(self):
78
+ self.vae.disable_slicing()
79
+
80
+ def enable_sequential_cpu_offload(self, gpu_id=0):
81
+ if is_accelerate_available():
82
+ from accelerate import cpu_offload
83
+ else:
84
+ raise ImportError("Please install accelerate via `pip install accelerate`")
85
+
86
+ device = torch.device(f"cuda:{gpu_id}") if torch.cuda.is_available() else torch.device("cpu")
87
+
88
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
89
+ if cpu_offloaded_model is not None:
90
+ cpu_offload(cpu_offloaded_model, device)
91
+
92
+ @property
93
+ def _execution_device(self):
94
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
95
+ return self.device
96
+ for module in self.unet.modules():
97
+ if (
98
+ hasattr(module, "_hf_hook")
99
+ and hasattr(module._hf_hook, "execution_device")
100
+ and module._hf_hook.execution_device is not None
101
+ ):
102
+ return torch.device(module._hf_hook.execution_device)
103
+ return self.device
104
+
105
+ def decode_latents(self, latents):
106
+ video_length = latents.shape[2]
107
+ latents = 1 / 0.18215 * latents
108
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
109
+ # video = self.vae.decode(latents).sample
110
+ video = []
111
+ for frame_idx in tqdm(range(latents.shape[0])):
112
+ video.append(self.vae.decode(latents[frame_idx : frame_idx + 1]).sample)
113
+ video = torch.cat(video)
114
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
115
+ video = (video / 2 + 0.5).clamp(0, 1)
116
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
117
+ video = video.cpu().float().numpy()
118
+ return video
119
+
120
+ def prepare_extra_step_kwargs(self, generator, eta):
121
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
122
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
123
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
124
+ # and should be between [0, 1]
125
+
126
+ accepts_eta = "eta" in set(
127
+ inspect.signature(self.scheduler.step).parameters.keys()
128
+ )
129
+ extra_step_kwargs = {}
130
+ if accepts_eta:
131
+ extra_step_kwargs["eta"] = eta
132
+
133
+ # check if the scheduler accepts generator
134
+ accepts_generator = "generator" in set(
135
+ inspect.signature(self.scheduler.step).parameters.keys()
136
+ )
137
+ if accepts_generator:
138
+ extra_step_kwargs["generator"] = generator
139
+ return extra_step_kwargs
140
+
141
+ def prepare_latents_bp(
142
+ self,
143
+ batch_size,
144
+ num_channels_latents,
145
+ width,
146
+ height,
147
+ video_length,
148
+ dtype,
149
+ device,
150
+ generator,
151
+ latents=None,
152
+ ):
153
+ shape = (
154
+ batch_size,
155
+ num_channels_latents,
156
+ video_length,
157
+ height // self.vae_scale_factor,
158
+ width // self.vae_scale_factor,
159
+ )
160
+ if isinstance(generator, list) and len(generator) != batch_size:
161
+ raise ValueError(
162
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
163
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
164
+ )
165
+
166
+ if latents is None:
167
+ latents = randn_tensor(
168
+ shape, generator=generator, device=device, dtype=dtype
169
+ )
170
+ else:
171
+ latents = latents.to(device)
172
+
173
+ # scale the initial noise by the standard deviation required by the scheduler
174
+ latents = latents * self.scheduler.init_noise_sigma
175
+ return latents
176
+
177
+ def prepare_latents(
178
+ self,
179
+ batch_size,
180
+ num_channels_latents,
181
+ width,
182
+ height,
183
+ video_length,
184
+ dtype,
185
+ device,
186
+ generator,
187
+ context_frame_length
188
+ ):
189
+ shape = (
190
+ batch_size,
191
+ num_channels_latents,
192
+ # context_frame_length,
193
+ video_length,
194
+ height // self.vae_scale_factor,
195
+ width // self.vae_scale_factor,
196
+ )
197
+
198
+ if isinstance(generator, list) and len(generator) != batch_size:
199
+ raise ValueError(
200
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
201
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
202
+ )
203
+
204
+ latents_seg = randn_tensor(
205
+ shape, generator=generator, device=device, dtype=dtype
206
+ )
207
+ latents = latents_seg
208
+
209
+ # scale the initial noise by the standard deviation required by the scheduler
210
+ latents = latents * self.scheduler.init_noise_sigma
211
+ print(f"latents shape:{latents.shape}, video_length:{video_length}")
212
+ return latents
213
+ def prepare_latents_smooth(
214
+ self,
215
+ batch_size,
216
+ num_channels_latents,
217
+ width,
218
+ height,
219
+ video_length,
220
+ dtype,
221
+ device,
222
+ generator,
223
+ context_frame_length
224
+ ):
225
+ shape = (
226
+ batch_size,
227
+ num_channels_latents,
228
+ # context_frame_length,
229
+ video_length,
230
+ height // self.vae_scale_factor,
231
+ width // self.vae_scale_factor,
232
+ )
233
+
234
+ if isinstance(generator, list) and len(generator) != batch_size:
235
+ raise ValueError(
236
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
237
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
238
+ )
239
+
240
+ latents_seg = randn_tensor(
241
+ shape, generator=generator, device=device, dtype=dtype
242
+ )
243
+
244
+ latents = latents_seg
245
+
246
+ latents = torch.clamp(latents_seg, -1.5, 1.5)
247
+
248
+
249
+ # scale the initial noise by the standard deviation required by the scheduler
250
+ latents = latents * self.scheduler.init_noise_sigma
251
+ print(f"latents shape:{latents.shape}, video_length:{video_length}")
252
+
253
+ return latents
254
+
255
+ def _encode_prompt(
256
+ self,
257
+ prompt,
258
+ device,
259
+ num_videos_per_prompt,
260
+ do_classifier_free_guidance,
261
+ negative_prompt,
262
+ ):
263
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
264
+
265
+ text_inputs = self.tokenizer(
266
+ prompt,
267
+ padding="max_length",
268
+ max_length=self.tokenizer.model_max_length,
269
+ truncation=True,
270
+ return_tensors="pt",
271
+ )
272
+ text_input_ids = text_inputs.input_ids
273
+ untruncated_ids = self.tokenizer(
274
+ prompt, padding="longest", return_tensors="pt"
275
+ ).input_ids
276
+
277
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
278
+ text_input_ids, untruncated_ids
279
+ ):
280
+ removed_text = self.tokenizer.batch_decode(
281
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
282
+ )
283
+
284
+ if (
285
+ hasattr(self.text_encoder.config, "use_attention_mask")
286
+ and self.text_encoder.config.use_attention_mask
287
+ ):
288
+ attention_mask = text_inputs.attention_mask.to(device)
289
+ else:
290
+ attention_mask = None
291
+
292
+ text_embeddings = self.text_encoder(
293
+ text_input_ids.to(device),
294
+ attention_mask=attention_mask,
295
+ )
296
+ text_embeddings = text_embeddings[0]
297
+
298
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
299
+ bs_embed, seq_len, _ = text_embeddings.shape
300
+ text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
301
+ text_embeddings = text_embeddings.view(
302
+ bs_embed * num_videos_per_prompt, seq_len, -1
303
+ )
304
+
305
+ # get unconditional embeddings for classifier free guidance
306
+ if do_classifier_free_guidance:
307
+ uncond_tokens: List[str]
308
+ if negative_prompt is None:
309
+ uncond_tokens = [""] * batch_size
310
+ elif type(prompt) is not type(negative_prompt):
311
+ raise TypeError(
312
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
313
+ f" {type(prompt)}."
314
+ )
315
+ elif isinstance(negative_prompt, str):
316
+ uncond_tokens = [negative_prompt]
317
+ elif batch_size != len(negative_prompt):
318
+ raise ValueError(
319
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
320
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
321
+ " the batch size of `prompt`."
322
+ )
323
+ else:
324
+ uncond_tokens = negative_prompt
325
+
326
+ max_length = text_input_ids.shape[-1]
327
+ uncond_input = self.tokenizer(
328
+ uncond_tokens,
329
+ padding="max_length",
330
+ max_length=max_length,
331
+ truncation=True,
332
+ return_tensors="pt",
333
+ )
334
+
335
+ if (
336
+ hasattr(self.text_encoder.config, "use_attention_mask")
337
+ and self.text_encoder.config.use_attention_mask
338
+ ):
339
+ attention_mask = uncond_input.attention_mask.to(device)
340
+ else:
341
+ attention_mask = None
342
+
343
+ uncond_embeddings = self.text_encoder(
344
+ uncond_input.input_ids.to(device),
345
+ attention_mask=attention_mask,
346
+ )
347
+ uncond_embeddings = uncond_embeddings[0]
348
+
349
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
350
+ seq_len = uncond_embeddings.shape[1]
351
+ uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
352
+ uncond_embeddings = uncond_embeddings.view(
353
+ batch_size * num_videos_per_prompt, seq_len, -1
354
+ )
355
+
356
+ # For classifier free guidance, we need to do two forward passes.
357
+ # Here we concatenate the unconditional and text embeddings into a single batch
358
+ # to avoid doing two forward passes
359
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
360
+
361
+ return text_embeddings
362
+
363
+ def interpolate_latents(
364
+ self, latents: torch.Tensor, interpolation_factor: int, device
365
+ ):
366
+ if interpolation_factor < 2:
367
+ return latents
368
+
369
+ new_latents = torch.zeros(
370
+ (
371
+ latents.shape[0],
372
+ latents.shape[1],
373
+ ((latents.shape[2] - 1) * interpolation_factor) + 1,
374
+ latents.shape[3],
375
+ latents.shape[4],
376
+ ),
377
+ device=latents.device,
378
+ dtype=latents.dtype,
379
+ )
380
+
381
+ org_video_length = latents.shape[2]
382
+ rate = [i / interpolation_factor for i in range(interpolation_factor)][1:]
383
+
384
+ new_index = 0
385
+
386
+ v0 = None
387
+ v1 = None
388
+
389
+ for i0, i1 in zip(range(org_video_length), range(org_video_length)[1:]):
390
+ v0 = latents[:, :, i0, :, :]
391
+ v1 = latents[:, :, i1, :, :]
392
+
393
+ new_latents[:, :, new_index, :, :] = v0
394
+ new_index += 1
395
+
396
+ for f in rate:
397
+ v = get_tensor_interpolation_method()(
398
+ v0.to(device=device), v1.to(device=device), f
399
+ )
400
+ new_latents[:, :, new_index, :, :] = v.to(latents.device)
401
+ new_index += 1
402
+
403
+ new_latents[:, :, new_index, :, :] = v1
404
+ new_index += 1
405
+
406
+ return new_latents
407
+
408
+ @torch.no_grad()
409
+ def __call__(
410
+ self,
411
+ ref_image,
412
+ audio_path,
413
+ poses_tensor,
414
+ width,
415
+ height,
416
+ video_length,
417
+ num_inference_steps,
418
+ guidance_scale,
419
+ num_images_per_prompt=1,
420
+ eta: float = 0.0,
421
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
422
+ output_type: Optional[str] = "tensor",
423
+ return_dict: bool = True,
424
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
425
+ callback_steps: Optional[int] = 1,
426
+ context_schedule="uniform",
427
+ context_frames=12,
428
+ context_stride=1,
429
+ context_overlap=0,
430
+ context_batch_size=1,
431
+ interpolation_factor=1,
432
+ audio_sample_rate=16000,
433
+ fps=25,
434
+ audio_margin=2,
435
+ start_idx=0,
436
+ **kwargs,
437
+ ):
438
+ # Default height and width to unet
439
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
440
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
441
+
442
+ device = self._execution_device
443
+
444
+ do_classifier_free_guidance = guidance_scale > 1.0
445
+
446
+ # Prepare timesteps
447
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
448
+ timesteps = self.scheduler.timesteps
449
+
450
+ batch_size = 1
451
+
452
+ reference_control_writer = ReferenceAttentionControl(
453
+ self.reference_unet,
454
+ do_classifier_free_guidance=do_classifier_free_guidance,
455
+ mode="write",
456
+ batch_size=batch_size,
457
+ fusion_blocks="full",
458
+ )
459
+ reference_control_reader = ReferenceAttentionControl(
460
+ self.denoising_unet,
461
+ do_classifier_free_guidance=do_classifier_free_guidance,
462
+ mode="read",
463
+ batch_size=batch_size,
464
+ fusion_blocks="full",
465
+ )
466
+
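+ # extract per-frame audio features with the Whisper-based audio guider, then chunk them to the video fps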
467
+ whisper_feature = self.audio_guider.audio2feat(audio_path)
468
+
469
+ whisper_chunks = self.audio_guider.feature2chunks(feature_array=whisper_feature, fps=fps)
470
+ audio_frame_num = whisper_chunks.shape[0]
471
+ audio_fea_final = torch.Tensor(whisper_chunks).to(dtype=self.vae.dtype, device=self.vae.device)
472
+ audio_fea_final = audio_fea_final.unsqueeze(0)
473
+
474
+ video_length = min(video_length, audio_frame_num)
475
+
476
+ num_channels_latents = self.denoising_unet.in_channels
477
+ latents = self.prepare_latents_smooth(
478
+ batch_size * num_images_per_prompt,
479
+ num_channels_latents,
480
+ width,
481
+ height,
482
+ video_length,
483
+ audio_fea_final.dtype,
484
+ device,
485
+ generator,
486
+ context_frames
487
+ )
488
+
489
+ pose_encoder_tensor = self.pose_encoder(poses_tensor)
490
+
491
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
492
+
493
+ # Prepare ref image latents
494
+ ref_image_tensor = self.ref_image_processor.preprocess(
495
+ ref_image, height=height, width=width
496
+ ) # (bs, c, height, width)
497
+ ref_image_tensor = ref_image_tensor.to(
498
+ dtype=self.vae.dtype, device=self.vae.device
499
+ )
500
+ ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean
501
+ ref_image_latents = ref_image_latents * 0.18215 # (b , 4, h, w)
502
+
503
+ context_scheduler = get_context_scheduler(context_schedule)
504
+
505
+ # denoising loop
506
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
507
+ context_queue = list(
508
+ context_scheduler(
509
+ 0,
510
+ num_inference_steps,
511
+ latents.shape[2],
512
+ context_frames,
513
+ context_stride,
514
+ context_overlap,
515
+ )
516
+ )
517
+
518
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
519
+ for t_i, t in enumerate(timesteps):
520
+
521
+ noise_pred = torch.zeros(
522
+ (
523
+ latents.shape[0] * (2 if do_classifier_free_guidance else 1),
524
+ *latents.shape[1:],
525
+ ),
526
+ device=latents.device,
527
+ dtype=latents.dtype,
528
+ )
529
+ counter = torch.zeros(
530
+ (1, 1, latents.shape[2], 1, 1),
531
+ device=latents.device,
532
+ dtype=latents.dtype,
533
+ )
534
+
535
+ # 1. Forward reference image
536
+ if t_i == 0:
537
+ self.reference_unet(
538
+ ref_image_latents,
539
+ torch.zeros_like(t),
540
+ encoder_hidden_states=None,
541
+ return_dict=False,
542
+ )
543
+ reference_control_reader.update(reference_control_writer, do_classifier_free_guidance=True)
544
+
545
+
546
+ num_context_batches = math.ceil(len(context_queue) / context_batch_size)
547
+
548
+ global_context = []
549
+ for j in range(num_context_batches):
550
+ global_context.append(
551
+ context_queue[
552
+ j * context_batch_size : (j + 1) * context_batch_size
553
+ ]
554
+ )
555
+
556
+ ## refine: shift each context window by t_i * 3 frames (mod video_length) so window boundaries differ across denoising steps
557
+ for context in global_context:
558
+ new_context = [[0 for _ in range(len(context[c_j]))] for c_j in range(len(context))]
559
+ for c_j in range(len(context)):
560
+ for c_i in range(len(context[c_j])):
561
+ new_context[c_j][c_i] = (context[c_j][c_i] + t_i * 3) % video_length
562
+
563
+
564
+ latent_model_input = (
565
+ torch.cat([latents[:, :, c] for c in new_context])
566
+ .to(device)
567
+ .repeat(2 if do_classifier_free_guidance else 1, 1, 1, 1, 1)
568
+ )
569
+
570
+ audio_latents_cond = torch.cat([audio_fea_final[:, c] for c in new_context]).to(device)
571
+
572
+ audio_latents = torch.cat([torch.zeros_like(audio_latents_cond), audio_latents_cond], 0)
573
+ pose_latents_cond = torch.cat([pose_encoder_tensor[:, :, c] for c in new_context]).to(device)
574
+ pose_latents = torch.cat([torch.zeros_like(pose_latents_cond), pose_latents_cond], 0)
575
+
576
+ latent_model_input = self.scheduler.scale_model_input(
577
+ latent_model_input, t
578
+ )
579
+ b, c, f, h, w = latent_model_input.shape
580
+
581
+ pred = self.denoising_unet(
582
+ latent_model_input,
583
+ t,
584
+ encoder_hidden_states=None,
585
+ audio_cond_fea=audio_latents if do_classifier_free_guidance else audio_latents_cond,
586
+ face_musk_fea=pose_latents if do_classifier_free_guidance else pose_latents_cond,
587
+ return_dict=False,
588
+ )[0]
589
+
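+ # accumulate predictions per context window and count how many windows cover each frame (for averaging before guidance)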
590
+ for j, c in enumerate(new_context):
591
+ noise_pred[:, :, c] = noise_pred[:, :, c] + pred
592
+ counter[:, :, c] = counter[:, :, c] + 1
593
+
594
+ # perform guidance
595
+ if do_classifier_free_guidance:
596
+ noise_pred_uncond, noise_pred_text = (noise_pred / counter).chunk(2)
597
+ noise_pred = noise_pred_uncond + guidance_scale * (
598
+ noise_pred_text - noise_pred_uncond
599
+ )
600
+
601
+ latents = self.scheduler.step(
602
+ noise_pred, t, latents, **extra_step_kwargs
603
+ ).prev_sample
604
+
605
+ if t_i == len(timesteps) - 1 or (
606
+ (t_i + 1) > num_warmup_steps and (t_i + 1) % self.scheduler.order == 0
607
+ ):
608
+ progress_bar.update()
609
+
610
+ reference_control_reader.clear()
611
+ reference_control_writer.clear()
612
+
613
+ if interpolation_factor > 0:
614
+ latents = self.interpolate_latents(latents, interpolation_factor, device)
615
+ # Post-processing
616
+ images = self.decode_latents(latents) # (b, c, f, h, w)
617
+
618
+ # Convert to tensor
619
+ if output_type == "tensor":
620
+ images = torch.from_numpy(images)
621
+
622
+ if not return_dict:
623
+ return images
624
+
625
+ return EchoMimicV2PipelineOutput(videos=images)
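
For reference, a minimal, hypothetical usage sketch inferred from the `__init__` and `__call__` signatures in the file above. The component loading (model classes, checkpoint paths, the `load_echomimic_models` helper, and the pose tensor shape) are assumptions for illustration and are not part of this commit; only the pipeline argument names come from the uploaded file.

import torch
from PIL import Image
from diffusers import DDIMScheduler

from src.pipelines.pipeline_echomimicv2 import EchoMimicV2Pipeline

# Hypothetical helper: builds vae, reference_unet, denoising_unet, audio_guider
# and pose_encoder from local checkpoints (not defined in this commit).
vae, reference_unet, denoising_unet, audio_guider, pose_encoder = load_echomimic_models()

pipe = EchoMimicV2Pipeline(
    vae=vae,
    reference_unet=reference_unet,
    denoising_unet=denoising_unet,
    audio_guider=audio_guider,
    pose_encoder=pose_encoder,
    scheduler=DDIMScheduler.from_pretrained("path/to/scheduler"),  # assumed checkpoint path
)

# Pose sequence of shape (batch, channels, frames, height, width); the real tensor
# comes from the repo's pose-preparation tooling, a zero tensor stands in here.
poses_tensor = torch.zeros(1, 3, 120, 768, 768, dtype=vae.dtype, device=vae.device)

output = pipe(
    ref_image=Image.open("reference.png").convert("RGB"),
    audio_path="speech.wav",
    poses_tensor=poses_tensor,
    width=768,
    height=768,
    video_length=120,
    num_inference_steps=30,
    guidance_scale=2.5,
    context_frames=12,
    context_overlap=3,
    fps=25,
)
video = output.videos  # (b, c, f, h, w) tensor with values in [0, 1]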