| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
import random
from typing import Dict, List, Tuple

import torch
from diffusers import ModularPipeline, ModularPipelineBlocks
from diffusers.modular_pipelines import PipelineState
from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from diffusers.utils import logging
from torchvision.transforms import v2
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
|
# Module-scoped logger obtained through diffusers' logging utilities.
logger = logging.get_logger(name=__name__)
|
|
|
|
class MatrixGameWanImageEncoderStep(ModularPipelineBlocks):
    """Modular pipeline block that encodes a conditioning image with a CLIP vision encoder.

    Reads ``image`` from the pipeline state and writes ``image_embeds``, which is
    the second-to-last entry of the encoder's ``hidden_states``, back into it.
    """

    model_name = "MatrixGameWan"

    @property
    def description(self) -> str:
        return "Image encoder step that generates image_embeds to guide the video generation"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        # The vision encoder and its preprocessor, with the default repos they are loaded from.
        return [
            ComponentSpec(
                "image_encoder",
                CLIPVisionModelWithProjection,
                repo="laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
            ),
            ComponentSpec(
                "image_processor",
                CLIPImageProcessor,
                repo="Wan-AI/Wan2.1-I2V-14B-720P-Diffusers",
                subfolder="image_processor",
            ),
        ]

    @property
    def expected_configs(self) -> List[ConfigSpec]:
        # This step declares no config values.
        return []

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("image"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "image_embeds",
                type_hint=torch.Tensor,
                description="image embeddings used to guide the video generation",
            )
        ]

    def encode_image(self, components, image, device=None):
        """Encode ``image`` with ``components.image_encoder``.

        Args:
            components: Pipeline exposing ``image_processor`` and ``image_encoder``.
            image: Image input passed straight to ``components.image_processor``.
            device: Device to run the encoder on. Defaults to
                ``components._execution_device``.

        Returns:
            ``hidden_states[-2]`` from the encoder's output, i.e. the penultimate
            hidden-state layer rather than the projected output.
        """
        if device is None:
            device = components._execution_device
        # Cast floating-point inputs (pixel values, presumably float32 from the
        # processor) to the encoder's dtype. Otherwise an fp16/bf16 encoder fails
        # with an input/weight dtype mismatch. Non-float tensors only move device.
        dtype = components.image_encoder.dtype
        image = components.image_processor(images=image, return_tensors="pt").to(device=device, dtype=dtype)
        image_embeds = components.image_encoder(**image, output_hidden_states=True)
        return image_embeds.hidden_states[-2]

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState) -> Tuple[ModularPipeline, PipelineState]:
        """Encode the state's ``image`` and store the result as ``image_embeds``.

        Args:
            components: Pipeline holding the ``image_encoder`` / ``image_processor`` components.
            state: Pipeline state providing ``image``.

        Returns:
            The ``(components, state)`` pair, after writing the block state back into ``state``.

        Raises:
            ValueError: If ``image`` was not provided.
        """
        block_state = self.get_block_state(state)
        block_state.device = components._execution_device

        if block_state.image is None:
            raise ValueError("`image` is required by MatrixGameWanImageEncoderStep but was not provided.")

        # Pass the already-resolved device so encode_image does not recompute it.
        block_state.image_embeds = self.encode_image(components, block_state.image, device=block_state.device)

        self.set_block_state(state, block_state)
        return components, state
|
|