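# Z-Image generation pipeline: text-to-image, image-to-image, image editing,
# ControlNet conditioning, and image-to-LoRA, built from modular PipelineUnits.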
import torch, math
from PIL import Image
from tqdm import tqdm
from einops import rearrange
import numpy as np
from typing import Union, List, Optional, Tuple, Iterable, Dict
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..core.data.operators import ImageCropAndResize
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
from ..utils.lora import merge_lora
from transformers import AutoTokenizer
from ..models.z_image_text_encoder import ZImageTextEncoder
from ..models.z_image_dit import ZImageDiT
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
from ..models.siglip2_image_encoder import Siglip2ImageEncoder428M, Siglip2ImageEncoder
from ..models.z_image_controlnet import ZImageControlNet
from ..models.dinov3_image_encoder import DINOv3ImageEncoder
from ..models.z_image_image2lora import ZImageImage2LoRAModel


class ZImagePipeline(BasePipeline):
    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
        super().__init__(
            device=device, torch_dtype=torch_dtype,
            height_division_factor=16, width_division_factor=16,
        )
        self.scheduler = FlowMatchScheduler("Z-Image")
        self.text_encoder: ZImageTextEncoder = None
        self.dit: ZImageDiT = None
        self.vae_encoder: FluxVAEEncoder = None
        self.vae_decoder: FluxVAEDecoder = None
        self.image_encoder: Siglip2ImageEncoder428M = None
        self.controlnet: ZImageControlNet = None
        self.siglip2_image_encoder: Siglip2ImageEncoder = None
        self.dinov3_image_encoder: DINOv3ImageEncoder = None
        self.image2lora_style: ZImageImage2LoRAModel = None
        self.tokenizer: AutoTokenizer = None
        self.in_iteration_models = ("dit", "controlnet")
        self.units = [
            ZImageUnit_ShapeChecker(),
            ZImageUnit_PromptEmbedder(),
            ZImageUnit_NoiseInitializer(),
            ZImageUnit_InputImageEmbedder(),
            ZImageUnit_EditImageAutoResize(),
            ZImageUnit_EditImageEmbedderVAE(),
            ZImageUnit_EditImageEmbedderSiglip(),
            ZImageUnit_PAIControlNet(),
        ]
        self.model_fn = model_fn_z_image

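    # Builds the pipeline from a list of ModelConfig entries: the configs are downloaded,
    # loaded into a model pool, and each component is fetched by its registered model name.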
    @staticmethod
    def from_pretrained(
        torch_dtype: torch.dtype = torch.bfloat16,
        device: Union[str, torch.device] = get_device_type(),
        model_configs: list[ModelConfig] = [],
        tokenizer_config: ModelConfig = ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
        vram_limit: float = None,
    ):
        # Initialize pipeline
        pipe = ZImagePipeline(device=device, torch_dtype=torch_dtype)
        model_pool = pipe.download_and_load_models(model_configs, vram_limit)
        # Fetch models
        pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder")
        pipe.dit = model_pool.fetch_model("z_image_dit")
        pipe.vae_encoder = model_pool.fetch_model("flux_vae_encoder")
        pipe.vae_decoder = model_pool.fetch_model("flux_vae_decoder")
        pipe.image_encoder = model_pool.fetch_model("siglip_vision_model_428m")
        pipe.controlnet = model_pool.fetch_model("z_image_controlnet")
        pipe.siglip2_image_encoder = model_pool.fetch_model("siglip2_image_encoder")
        pipe.dinov3_image_encoder = model_pool.fetch_model("dinov3_image_encoder")
        pipe.image2lora_style = model_pool.fetch_model("z_image_image2lora_style")
        if tokenizer_config is not None:
            tokenizer_config.download_if_necessary()
            pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
        # VRAM Management
        pipe.vram_management_enabled = pipe.check_vram_management_state()
        return pipe

    def __call__(
        self,
        # Prompt
        prompt: str,
        negative_prompt: str = "",
        cfg_scale: float = 1.0,
        # Image
        input_image: Image.Image = None,
        denoising_strength: float = 1.0,
        # Edit
        edit_image: Image.Image = None,
        edit_image_auto_resize: bool = True,
        # Shape
        height: int = 1024,
        width: int = 1024,
        # Randomness
        seed: int = None,
        rand_device: str = "cpu",
        # Steps
        num_inference_steps: int = 8,
        sigma_shift: float = None,
        # ControlNet
        controlnet_inputs: List[ControlNetInput] = None,
        # Image to LoRA
        image2lora_images: List[Image.Image] = None,
        positive_only_lora: Dict[str, torch.Tensor] = None,
        # Progress bar
        progress_bar_cmd=tqdm,
    ):
        # Scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
        # Parameters
        inputs_posi = {
            "prompt": prompt,
        }
        inputs_nega = {
            "negative_prompt": negative_prompt,
        }
        inputs_shared = {
            "cfg_scale": cfg_scale,
            "input_image": input_image, "denoising_strength": denoising_strength,
            "height": height, "width": width,
            "seed": seed, "rand_device": rand_device,
            "num_inference_steps": num_inference_steps,
            "edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize,
            "controlnet_inputs": controlnet_inputs,
            "image2lora_images": image2lora_images, "positive_only_lora": positive_only_lora,
        }
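        # Each PipelineUnit reads its declared input_params from these dicts and writes its
        # output_params back, progressively preparing the inputs for the denoising loop.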
        for unit in self.units:
            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
        # Denoise
        self.load_models_to_device(self.in_iteration_models)
        models = {name: getattr(self, name) for name in self.in_iteration_models}
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
            noise_pred = self.cfg_guided_model_fn(
                self.model_fn, cfg_scale,
                inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id
            )
            inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
        # Decode
        self.load_models_to_device(['vae_decoder'])
        image = self.vae_decoder(inputs_shared["latents"])
        image = self.vae_output_to_image(image)
        self.load_models_to_device([])
        return image


class ZImageUnit_ShapeChecker(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("height", "width"),
            output_params=("height", "width"),
        )

    def process(self, pipe: ZImagePipeline, height, width):
        height, width = pipe.check_resize_height_width(height, width)
        return {"height": height, "width": width}


class ZImageUnit_PromptEmbedder(PipelineUnit):
    def __init__(self):
        super().__init__(
            seperate_cfg=True,
            input_params=("edit_image",),
            input_params_posi={"prompt": "prompt"},
            input_params_nega={"prompt": "negative_prompt"},
            output_params=("prompt_embeds",),
            onload_model_names=("text_encoder",)
        )

    def encode_prompt(
        self,
        pipe,
        prompt: Union[str, List[str]],
        device: Optional[torch.device] = None,
        max_sequence_length: int = 512,
    ) -> List[torch.FloatTensor]:
        if isinstance(prompt, str):
            prompt = [prompt]
        for i, prompt_item in enumerate(prompt):
            messages = [
                {"role": "user", "content": prompt_item},
            ]
            prompt_item = pipe.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=True,
            )
            prompt[i] = prompt_item
        text_inputs = pipe.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids.to(device)
        prompt_masks = text_inputs.attention_mask.to(device).bool()
        prompt_embeds = pipe.text_encoder(
            input_ids=text_input_ids,
            attention_mask=prompt_masks,
            output_hidden_states=True,
        ).hidden_states[-2]
        embeddings_list = []
        for i in range(len(prompt_embeds)):
            embeddings_list.append(prompt_embeds[i][prompt_masks[i]])
        return embeddings_list

    def encode_prompt_omni(
        self,
        pipe,
        prompt: Union[str, List[str]],
        edit_image=None,
        device: Optional[torch.device] = None,
        max_sequence_length: int = 512,
    ) -> List[torch.FloatTensor]:
        if isinstance(prompt, str):
            prompt = [prompt]
        if edit_image is None:
            num_condition_images = 0
        elif isinstance(edit_image, list):
            num_condition_images = len(edit_image)
        else:
            num_condition_images = 1
        for i, prompt_item in enumerate(prompt):
            if num_condition_images == 0:
                prompt[i] = ["<|im_start|>user\n" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n"]
            elif num_condition_images > 0:
                prompt_list = ["<|im_start|>user\n<|vision_start|>"]
                prompt_list += ["<|vision_end|><|vision_start|>"] * (num_condition_images - 1)
                prompt_list += ["<|vision_end|>" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n<|vision_start|>"]
                prompt_list += ["<|vision_end|><|im_end|>"]
                prompt[i] = prompt_list
        flattened_prompt = []
        prompt_list_lengths = []
        for i in range(len(prompt)):
            prompt_list_lengths.append(len(prompt[i]))
            flattened_prompt.extend(prompt[i])
        text_inputs = pipe.tokenizer(
            flattened_prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids.to(device)
        prompt_masks = text_inputs.attention_mask.to(device).bool()
        prompt_embeds = pipe.text_encoder(
            input_ids=text_input_ids,
            attention_mask=prompt_masks,
            output_hidden_states=True,
        ).hidden_states[-2]
        embeddings_list = []
        start_idx = 0
        for i in range(len(prompt_list_lengths)):
            batch_embeddings = []
            end_idx = start_idx + prompt_list_lengths[i]
            for j in range(start_idx, end_idx):
                batch_embeddings.append(prompt_embeds[j][prompt_masks[j]])
            embeddings_list.append(batch_embeddings)
            start_idx = end_idx
        return embeddings_list

    def process(self, pipe: ZImagePipeline, prompt, edit_image):
        pipe.load_models_to_device(self.onload_model_names)
        if hasattr(pipe, "dit") and pipe.dit.siglip_embedder is not None:
            # Z-Image-Turbo and Z-Image-Omni-Base use different prompt encoding methods.
            # We determine which encoding method to use based on the model architecture.
            # If you are using two-stage split training,
            # please use `--offload_models` instead of skipping the DiT model loading.
            prompt_embeds = self.encode_prompt_omni(pipe, prompt, edit_image, pipe.device)
        else:
            prompt_embeds = self.encode_prompt(pipe, prompt, pipe.device)
        return {"prompt_embeds": prompt_embeds}


class ZImageUnit_NoiseInitializer(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("height", "width", "seed", "rand_device"),
            output_params=("noise",),
        )

    def process(self, pipe: ZImagePipeline, height, width, seed, rand_device):
        noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
        return {"noise": noise}


class ZImageUnit_InputImageEmbedder(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("input_image", "noise"),
            output_params=("latents", "input_latents"),
            onload_model_names=("vae_encoder",)
        )

    def process(self, pipe: ZImagePipeline, input_image, noise):
        if input_image is None:
            return {"latents": noise, "input_latents": None}
        pipe.load_models_to_device(self.onload_model_names)
        image = pipe.preprocess_image(input_image)
        input_latents = pipe.vae_encoder(image)
        if pipe.scheduler.training:
            return {"latents": noise, "input_latents": input_latents}
        else:
            latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
            return {"latents": latents, "input_latents": input_latents}


class ZImageUnit_EditImageAutoResize(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("edit_image", "edit_image_auto_resize"),
            output_params=("edit_image",),
        )

    def process(self, pipe: ZImagePipeline, edit_image, edit_image_auto_resize):
        if edit_image is None:
            return {}
        if edit_image_auto_resize is None or not edit_image_auto_resize:
            return {}
        operator = ImageCropAndResize(max_pixels=1024*1024, height_division_factor=16, width_division_factor=16)
        if not isinstance(edit_image, list):
            edit_image = [edit_image]
        edit_image = [operator(i) for i in edit_image]
        return {"edit_image": edit_image}


class ZImageUnit_EditImageEmbedderSiglip(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("edit_image",),
            output_params=("image_embeds",),
            onload_model_names=("image_encoder",)
        )

    def process(self, pipe: ZImagePipeline, edit_image):
        if edit_image is None:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        if not isinstance(edit_image, list):
            edit_image = [edit_image]
        image_emb = []
        for image_ in edit_image:
            image_emb.append(pipe.image_encoder(image_, device=pipe.device))
        return {"image_embeds": image_emb}


class ZImageUnit_EditImageEmbedderVAE(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("edit_image",),
            output_params=("image_latents",),
            onload_model_names=("vae_encoder",)
        )

    def process(self, pipe: ZImagePipeline, edit_image):
        if edit_image is None:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        if not isinstance(edit_image, list):
            edit_image = [edit_image]
        image_latents = []
        for image_ in edit_image:
            image_ = pipe.preprocess_image(image_)
            image_latents.append(pipe.vae_encoder(image_))
        return {"image_latents": image_latents}


class ZImageUnit_PAIControlNet(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("controlnet_inputs", "height", "width"),
            output_params=("control_context", "control_scale"),
            onload_model_names=("vae_encoder",)
        )

    def process(self, pipe: ZImagePipeline, controlnet_inputs: List[ControlNetInput], height, width):
        if controlnet_inputs is None:
            return {}
        if len(controlnet_inputs) != 1:
            print("Z-Image ControlNet doesn't support multi-ControlNet. Only one image will be used.")
        controlnet_input = controlnet_inputs[0]
        pipe.load_models_to_device(self.onload_model_names)
        control_image = controlnet_input.image
        if control_image is not None:
            control_image = pipe.preprocess_image(control_image)
            control_latents = pipe.vae_encoder(control_image)
        else:
            control_latents = torch.ones((1, 16, height // 8, width // 8), dtype=pipe.torch_dtype, device=pipe.device) * -1
        inpaint_mask = controlnet_input.inpaint_mask
        if inpaint_mask is not None:
            inpaint_mask = pipe.preprocess_image(inpaint_mask, min_value=0, max_value=1)
            inpaint_image = controlnet_input.inpaint_image
            inpaint_image = pipe.preprocess_image(inpaint_image)
            inpaint_image = inpaint_image * (inpaint_mask < 0.5)
            inpaint_mask = torch.nn.functional.interpolate(1 - inpaint_mask, (height // 8, width // 8), mode='nearest')[:, :1]
        else:
            inpaint_mask = torch.zeros((1, 1, height // 8, width // 8), dtype=pipe.torch_dtype, device=pipe.device)
            inpaint_image = torch.zeros((1, 3, height, width), dtype=pipe.torch_dtype, device=pipe.device)
        inpaint_latent = pipe.vae_encoder(inpaint_image)
        control_context = torch.concat([control_latents, inpaint_mask, inpaint_latent], dim=1)
        control_context = rearrange(control_context, "B C H W -> B C 1 H W")
        return {"control_context": control_context, "control_scale": controlnet_input.scale}


def model_fn_z_image(
    dit: ZImageDiT,
    controlnet: ZImageControlNet = None,
    latents=None,
    timestep=None,
    prompt_embeds=None,
    image_embeds=None,
    image_latents=None,
    use_gradient_checkpointing=False,
    use_gradient_checkpointing_offload=False,
    **kwargs,
):
    # Due to the complex and verbose codebase of Z-Image,
    # we are temporarily using this inelegant structure.
    # We will refactor this part in the future (if time permits).
    if dit.siglip_embedder is None:
        return model_fn_z_image_turbo(
            dit,
            controlnet=controlnet,
            latents=latents,
            timestep=timestep,
            prompt_embeds=prompt_embeds,
            image_embeds=image_embeds,
            image_latents=image_latents,
            use_gradient_checkpointing=use_gradient_checkpointing,
            use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
            **kwargs,
        )
    latents = [rearrange(latents, "B C H W -> C B H W")]
    if dit.siglip_embedder is not None:
        if image_latents is not None:
            image_latents = [rearrange(image_latent, "B C H W -> C B H W") for image_latent in image_latents]
            latents = [image_latents + latents]
            image_noise_mask = [[0] * len(image_latents) + [1]]
        else:
            latents = [latents]
            image_noise_mask = [[1]]
        image_embeds = [image_embeds]
    else:
        image_noise_mask = None
    timestep = (1000 - timestep) / 1000
    model_output = dit(
        latents,
        timestep,
        prompt_embeds,
        siglip_feats=image_embeds,
        image_noise_mask=image_noise_mask,
        use_gradient_checkpointing=use_gradient_checkpointing,
        use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
    )[0]
    model_output = -model_output
    model_output = rearrange(model_output, "C B H W -> B C H W")
    return model_output


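# Encodes reference images for image-to-LoRA: each image is cropped/resized to 1024x1024,
# embedded with both the SigLIP2 and DINOv3 encoders, and the two embeddings are
# concatenated along the feature dimension.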
class ZImageUnit_Image2LoRAEncode(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("image2lora_images",),
            output_params=("image2lora_x",),
            onload_model_names=("siglip2_image_encoder", "dinov3_image_encoder",),
        )
        self.processor_highres = ImageCropAndResize(height=1024, width=1024)

    def encode_images_using_siglip2(self, pipe: ZImagePipeline, images: list[Image.Image]):
        pipe.load_models_to_device(["siglip2_image_encoder"])
        embs = []
        for image in images:
            image = self.processor_highres(image)
            embs.append(pipe.siglip2_image_encoder(image).to(pipe.torch_dtype))
        embs = torch.stack(embs)
        return embs

    def encode_images_using_dinov3(self, pipe: ZImagePipeline, images: list[Image.Image]):
        pipe.load_models_to_device(["dinov3_image_encoder"])
        embs = []
        for image in images:
            image = self.processor_highres(image)
            embs.append(pipe.dinov3_image_encoder(image).to(pipe.torch_dtype))
        embs = torch.stack(embs)
        return embs

    def encode_images(self, pipe: ZImagePipeline, images: list[Image.Image]):
        if images is None:
            return {}
        if not isinstance(images, list):
            images = [images]
        embs_siglip2 = self.encode_images_using_siglip2(pipe, images)
        embs_dinov3 = self.encode_images_using_dinov3(pipe, images)
        x = torch.concat([embs_siglip2, embs_dinov3], dim=-1)
        return x

    def process(self, pipe: ZImagePipeline, image2lora_images):
        if image2lora_images is None:
            return {}
        x = self.encode_images(pipe, image2lora_images)
        return {"image2lora_x": x}


class ZImageUnit_Image2LoRADecode(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("image2lora_x",),
            output_params=("lora",),
            onload_model_names=("image2lora_style",),
        )

    def process(self, pipe: ZImagePipeline, image2lora_x):
        if image2lora_x is None:
            return {}
        loras = []
        if pipe.image2lora_style is not None:
            pipe.load_models_to_device(["image2lora_style"])
            for x in image2lora_x:
                loras.append(pipe.image2lora_style(x=x, residual=None))
        lora = merge_lora(loras, alpha=1 / len(image2lora_x))
        return {"lora": lora}


def model_fn_z_image_turbo(
    dit: ZImageDiT,
    controlnet: ZImageControlNet = None,
    latents=None,
    timestep=None,
    prompt_embeds=None,
    image_embeds=None,
    image_latents=None,
    control_context=None,
    control_scale=None,
    use_gradient_checkpointing=False,
    use_gradient_checkpointing_offload=False,
    **kwargs,
):
    while isinstance(prompt_embeds, list):
        prompt_embeds = prompt_embeds[0]
    while isinstance(latents, list):
        latents = latents[0]
    while isinstance(image_embeds, list):
        image_embeds = image_embeds[0]
    # Timestep
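    # The DiT's timestep convention is inverted relative to the scheduler
    # (1000 corresponds to a clean latent), so the scheduler timestep is flipped here.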
    timestep = 1000 - timestep
    t_noisy = dit.t_embedder(timestep)
    t_clean = dit.t_embedder(torch.ones_like(timestep) * 1000)
    # Patchify
    latents = rearrange(latents, "B C H W -> C B H W")
    x, cap_feats, patch_metadata = dit.patchify_and_embed([latents], [prompt_embeds])
    x = x[0]
    cap_feats = cap_feats[0]
    # Noise refine
    x = dit.all_x_embedder["2-1"](x)
    x[torch.cat(patch_metadata.get("x_pad_mask"))] = dit.x_pad_token.to(dtype=x.dtype, device=x.device)
    x_freqs_cis = dit.rope_embedder(torch.cat(patch_metadata.get("x_pos_ids"), dim=0))
    x = rearrange(x, "L C -> 1 L C")
    x_freqs_cis = rearrange(x_freqs_cis, "L C -> 1 L C")
    if control_context is not None:
        kwargs = dict(attn_mask=None, freqs_cis=x_freqs_cis, adaln_input=t_noisy)
        refiner_hints, control_context, control_context_item_seqlens = controlnet.forward_refiner(
            dit, x, [cap_feats], control_context, kwargs, t=t_noisy, patch_size=2, f_patch_size=1,
            use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
        )
    for layer_id, layer in enumerate(dit.noise_refiner):
        x = gradient_checkpoint_forward(
            layer,
            use_gradient_checkpointing=use_gradient_checkpointing,
            use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
            x=x,
            attn_mask=None,
            freqs_cis=x_freqs_cis,
            adaln_input=t_noisy,
        )
        if control_context is not None:
            x = x + refiner_hints[layer_id] * control_scale
    # Prompt refine
    cap_feats = dit.cap_embedder(cap_feats)
    cap_feats[torch.cat(patch_metadata.get("cap_pad_mask"))] = dit.cap_pad_token.to(dtype=x.dtype, device=x.device)
    cap_freqs_cis = dit.rope_embedder(torch.cat(patch_metadata.get("cap_pos_ids"), dim=0))
    cap_feats = rearrange(cap_feats, "L C -> 1 L C")
    cap_freqs_cis = rearrange(cap_freqs_cis, "L C -> 1 L C")
    for layer in dit.context_refiner:
        cap_feats = gradient_checkpoint_forward(
            layer,
            use_gradient_checkpointing=use_gradient_checkpointing,
            use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
            x=cap_feats,
            attn_mask=None,
            freqs_cis=cap_freqs_cis,
        )
    # Unified
    unified = torch.cat([x, cap_feats], dim=1)
    unified_freqs_cis = torch.cat([x_freqs_cis, cap_freqs_cis], dim=1)
    if control_context is not None:
        kwargs = dict(attn_mask=None, freqs_cis=unified_freqs_cis, adaln_input=t_noisy)
        hints = controlnet.forward_layers(
            unified, cap_feats, control_context, control_context_item_seqlens, kwargs,
            use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
        )
    for layer_id, layer in enumerate(dit.layers):
        unified = gradient_checkpoint_forward(
            layer,
            use_gradient_checkpointing=use_gradient_checkpointing,
            use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
            x=unified,
            attn_mask=None,
            freqs_cis=unified_freqs_cis,
            adaln_input=t_noisy,
        )
        if control_context is not None:
            if layer_id in controlnet.control_layers_mapping:
                unified = unified + hints[controlnet.control_layers_mapping[layer_id]] * control_scale
    # Output
    unified = dit.all_final_layer["2-1"](unified, t_noisy)
    x = dit.unpatchify([unified[0]], patch_metadata.get("x_size"))[0]
    x = rearrange(x, "C B H W -> B C H W")
    x = -x
    return x
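

# Example usage (a minimal sketch; the ModelConfig origin_file_pattern values below are
# assumptions — substitute the actual Z-Image checkpoint entries for your setup):
#
#     pipe = ZImagePipeline.from_pretrained(
#         torch_dtype=torch.bfloat16,
#         model_configs=[
#             ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="dit/"),           # hypothetical pattern
#             ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/"),  # hypothetical pattern
#             ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/"),           # hypothetical pattern
#         ],
#     )
#     image = pipe(prompt="a golden retriever puppy in the snow", seed=0, num_inference_steps=8)
#     image.save("z_image_output.png")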