| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """Decoder blocks for WorldEngine modular pipeline.""" |
|
|
from typing import List, Tuple, Union

import numpy as np
import PIL.Image
import torch

from diffusers import AutoModel
from diffusers.configuration_utils import FrozenDict
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils import logging
from diffusers.modular_pipelines import (
    ModularPipelineBlocks,
    ModularPipeline,
    PipelineState,
)
from diffusers.modular_pipelines.modular_pipeline_utils import (
    ComponentSpec,
    InputParam,
    OutputParam,
)
|
|
| logger = logging.get_logger(__name__) |
|
|
|
|
class WorldEngineDecodeStep(ModularPipelineBlocks):
    """Decodes denoised latents back to RGB image using VAE.

    Reads ``latents`` and ``output_type`` from the pipeline state, stores the
    decoded result (or the raw latents for ``output_type="latent"``) as
    ``images``, and clears ``latents`` on the block state afterwards.
    """

    model_name = "world_engine"

    # Output formats handled explicitly by ``__call__``. Any other value falls
    # back to "pil" after logging a warning.
    _SUPPORTED_OUTPUT_TYPES = ("pil", "latent", "pt", "np")

    @property
    def expected_components(self) -> List[ComponentSpec]:
        """Components this block needs: the VAE and an image processor."""
        # NOTE(review): ``image_processor`` is declared here but never used by
        # ``__call__`` -- confirm whether post-processing should go through it.
        return [
            ComponentSpec("vae", AutoModel),
            ComponentSpec(
                "image_processor",
                VaeImageProcessor,
                config=FrozenDict(
                    {
                        "vae_scale_factor": 16,
                        "do_normalize": False,
                        "do_convert_rgb": True,
                    }
                ),
                default_creation_method="from_config",
            ),
        ]

    @property
    def description(self) -> str:
        return "Decodes denoised latents to RGB image using the VAE decoder"

    @property
    def inputs(self) -> List[InputParam]:
        """State values read by this block: ``latents`` and ``output_type``."""
        return [
            InputParam(
                "latents",
                required=True,
                type_hint=torch.Tensor,
                description="Denoised latent tensor [1, 1, C, H, W]",
            ),
            InputParam(
                "output_type",
                default="pil",
                description="The output format for the generated images (pil, latent, pt, or np)",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        """State values written by this block: ``images``."""
        return [
            OutputParam(
                "images",
                type_hint=Union[PIL.Image.Image, torch.Tensor, np.ndarray],
                description="Decoded RGB image in requested output format",
            ),
        ]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> Tuple[ModularPipeline, PipelineState]:
        """Decode ``latents`` into ``images`` in the requested format.

        Args:
            components: Pipeline whose ``vae`` component is used for decoding.
            state: Pipeline state providing ``latents`` and ``output_type``.

        Returns:
            The ``(components, state)`` pair. ``state`` now carries ``images``,
            and ``latents`` is set to ``None``.
        """
        block_state = self.get_block_state(state)
        latents = block_state.latents
        output_type = block_state.output_type or "pil"

        # Previously an unknown value silently produced a PIL image. Keep that
        # fallback for compatibility, but make it visible.
        if output_type not in self._SUPPORTED_OUTPUT_TYPES:
            logger.warning(
                "Unsupported output_type %r; falling back to 'pil'. Expected one of %s.",
                output_type,
                self._SUPPORTED_OUTPUT_TYPES,
            )
            output_type = "pil"

        if output_type == "latent":
            block_state.images = latents
        else:
            # Drop dim 1 (documented as size 1 in the ``latents`` input
            # description, [1, 1, C, H, W]) before handing off to the VAE.
            image = components.vae.decode(latents.squeeze(1))

            if output_type == "pt":
                block_state.images = image
            else:
                array = self._to_numpy(image)
                if output_type == "np":
                    block_state.images = array
                else:
                    # NOTE(review): assumes vae.decode returns uint8 data laid
                    # out the way PIL.Image.fromarray accepts (e.g. H x W x 3)
                    # -- TODO confirm against the WorldEngine VAE.
                    block_state.images = PIL.Image.fromarray(array)

        # Latents have been consumed; clear them on the block state.
        block_state.latents = None
        self.set_block_state(state, block_state)
        return components, state

    @staticmethod
    def _to_numpy(image: torch.Tensor) -> np.ndarray:
        """Move ``image`` to CPU and return it as a NumPy array.

        ``Tensor.numpy()`` rejects bfloat16 because NumPy has no matching
        dtype, so bf16 tensors are upcast to float32 first. All other dtypes
        are converted unchanged.
        """
        cpu_image = image.cpu()
        if cpu_image.dtype == torch.bfloat16:
            cpu_image = cpu_image.float()
        return cpu_image.numpy()
|
|