import torch, math
from PIL import Image
from typing import Union
from tqdm import tqdm
from einops import rearrange
import numpy as np
from math import prod
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
from ..utils.lora.merge import merge_lora
from ..models.qwen_image_dit import QwenImageDiT
from ..models.qwen_image_text_encoder import QwenImageTextEncoder
from ..models.qwen_image_vae import QwenImageVAE
from ..models.qwen_image_controlnet import QwenImageBlockWiseControlNet
from ..models.siglip2_image_encoder import Siglip2ImageEncoder
from ..models.dinov3_image_encoder import DINOv3ImageEncoder
from ..models.qwen_image_image2lora import QwenImageImage2LoRAModel
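
# QwenImagePipeline wires together the text encoder, DiT, VAE and the optional
# ControlNet / image-encoder components, and drives flow-matching denoising with
# FlowMatchScheduler.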
class QwenImagePipeline(BasePipeline):
    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
        super().__init__(
            device=device, torch_dtype=torch_dtype,
            height_division_factor=16, width_division_factor=16,
        )
        from transformers import Qwen2Tokenizer, Qwen2VLProcessor
        self.scheduler = FlowMatchScheduler("Qwen-Image")
        self.text_encoder: QwenImageTextEncoder = None
        self.dit: QwenImageDiT = None
        self.vae: QwenImageVAE = None
        self.blockwise_controlnet: QwenImageBlockwiseMultiControlNet = None
        self.tokenizer: Qwen2Tokenizer = None
        self.siglip2_image_encoder: Siglip2ImageEncoder = None
        self.dinov3_image_encoder: DINOv3ImageEncoder = None
        self.image2lora_style: QwenImageImage2LoRAModel = None
        self.image2lora_coarse: QwenImageImage2LoRAModel = None
        self.image2lora_fine: QwenImageImage2LoRAModel = None
        self.processor: Qwen2VLProcessor = None
        self.in_iteration_models = ("dit", "blockwise_controlnet")
        self.units = [
            QwenImageUnit_ShapeChecker(),
            QwenImageUnit_NoiseInitializer(),
            QwenImageUnit_InputImageEmbedder(),
            QwenImageUnit_Inpaint(),
            QwenImageUnit_EditImageEmbedder(),
            QwenImageUnit_LayerInputImageEmbedder(),
            QwenImageUnit_ContextImageEmbedder(),
            QwenImageUnit_PromptEmbedder(),
            QwenImageUnit_EntityControl(),
            QwenImageUnit_BlockwiseControlNet(),
        ]
        self.model_fn = model_fn_qwen_image
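
    # Builds a pipeline from downloadable model configs: the models are loaded into a
    # model pool and attached to the pipeline attributes by name.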
    @staticmethod
    def from_pretrained(
        torch_dtype: torch.dtype = torch.bfloat16,
        device: Union[str, torch.device] = get_device_type(),
        model_configs: list[ModelConfig] = [],
        tokenizer_config: ModelConfig = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
        processor_config: ModelConfig = None,
        vram_limit: float = None,
    ):
        # Initialize pipeline
        pipe = QwenImagePipeline(device=device, torch_dtype=torch_dtype)
        model_pool = pipe.download_and_load_models(model_configs, vram_limit)
        # Fetch models
        pipe.text_encoder = model_pool.fetch_model("qwen_image_text_encoder")
        pipe.dit = model_pool.fetch_model("qwen_image_dit")
        pipe.vae = model_pool.fetch_model("qwen_image_vae")
        pipe.blockwise_controlnet = QwenImageBlockwiseMultiControlNet(model_pool.fetch_model("qwen_image_blockwise_controlnet", index="all"))
        if tokenizer_config is not None:
            tokenizer_config.download_if_necessary()
            from transformers import Qwen2Tokenizer
            pipe.tokenizer = Qwen2Tokenizer.from_pretrained(tokenizer_config.path)
        if processor_config is not None:
            processor_config.download_if_necessary()
            from transformers import Qwen2VLProcessor
            pipe.processor = Qwen2VLProcessor.from_pretrained(processor_config.path)
        pipe.siglip2_image_encoder = model_pool.fetch_model("siglip2_image_encoder")
        pipe.dinov3_image_encoder = model_pool.fetch_model("dinov3_image_encoder")
        pipe.image2lora_style = model_pool.fetch_model("qwen_image_image2lora_style")
        pipe.image2lora_coarse = model_pool.fetch_model("qwen_image_image2lora_coarse")
        pipe.image2lora_fine = model_pool.fetch_model("qwen_image_image2lora_fine")
        # VRAM Management
        pipe.vram_management_enabled = pipe.check_vram_management_state()
        return pipe
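
    # Text-to-image / image-to-image entry point. Each PipelineUnit prepares one group of
    # model inputs, then the loop below runs CFG-guided flow-matching denoising and the
    # VAE decodes the final latents.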
    def __call__(
        self,
        # Prompt
        prompt: str,
        negative_prompt: str = "",
        cfg_scale: float = 4.0,
        # Image
        input_image: Image.Image = None,
        denoising_strength: float = 1.0,
        # Inpaint
        inpaint_mask: Image.Image = None,
        inpaint_blur_size: int = None,
        inpaint_blur_sigma: float = None,
        # Shape
        height: int = 1328,
        width: int = 1328,
        # Randomness
        seed: int = None,
        rand_device: str = "cpu",
        # Steps
        num_inference_steps: int = 30,
        exponential_shift_mu: float = None,
        # Blockwise ControlNet
        blockwise_controlnet_inputs: list[ControlNetInput] = None,
        # EliGen
        eligen_entity_prompts: list[str] = None,
        eligen_entity_masks: list[Image.Image] = None,
        eligen_enable_on_negative: bool = False,
        # Qwen-Image-Edit
        edit_image: Image.Image = None,
        edit_image_auto_resize: bool = True,
        edit_rope_interpolation: bool = False,
        # Qwen-Image-Edit-2511
        zero_cond_t: bool = False,
        # Qwen-Image-Layered
        layer_input_image: Image.Image = None,
        layer_num: int = None,
        # In-context control
        context_image: Image.Image = None,
        # Tile
        tiled: bool = False,
        tile_size: int = 128,
        tile_stride: int = 64,
        # Progress bar
        progress_bar_cmd=tqdm,
    ):
        # Scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=(height // 16) * (width // 16), exponential_shift_mu=exponential_shift_mu)
        # Parameters
        inputs_posi = {
            "prompt": prompt,
        }
        inputs_nega = {
            "negative_prompt": negative_prompt,
        }
        inputs_shared = {
            "cfg_scale": cfg_scale,
            "input_image": input_image, "denoising_strength": denoising_strength,
            "inpaint_mask": inpaint_mask, "inpaint_blur_size": inpaint_blur_size, "inpaint_blur_sigma": inpaint_blur_sigma,
            "height": height, "width": width,
            "seed": seed, "rand_device": rand_device,
            "num_inference_steps": num_inference_steps,
            "blockwise_controlnet_inputs": blockwise_controlnet_inputs,
            "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
            "eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative,
            "edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize, "edit_rope_interpolation": edit_rope_interpolation,
            "context_image": context_image,
            "zero_cond_t": zero_cond_t,
            "layer_input_image": layer_input_image,
            "layer_num": layer_num,
        }
        for unit in self.units:
            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
        # Denoise
        self.load_models_to_device(self.in_iteration_models)
        models = {name: getattr(self, name) for name in self.in_iteration_models}
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
            noise_pred = self.cfg_guided_model_fn(
                self.model_fn, cfg_scale,
                inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id
            )
            inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
        # Decode
        self.load_models_to_device(['vae'])
        image = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        if layer_num is None:
            image = self.vae_output_to_image(image)
        else:
            image = [self.vae_output_to_image(i, pattern="C H W") for i in image]
        self.load_models_to_device([])
        return image
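
# Wraps one or more block-wise ControlNets so several control conditions can be injected
# into the DiT blocks at once, each with its own scale and active step range.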
class QwenImageBlockwiseMultiControlNet(torch.nn.Module):
    def __init__(self, models: list[QwenImageBlockWiseControlNet]):
        super().__init__()
        if not isinstance(models, list):
            models = [models]
        self.models = torch.nn.ModuleList(models)
        for model in models:
            if hasattr(model, "vram_management_enabled") and getattr(model, "vram_management_enabled"):
                self.vram_management_enabled = True

    def preprocess(self, controlnet_inputs: list[ControlNetInput], conditionings: list[torch.Tensor], **kwargs):
        processed_conditionings = []
        for controlnet_input, conditioning in zip(controlnet_inputs, conditionings):
            conditioning = rearrange(conditioning, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
            model_output = self.models[controlnet_input.controlnet_id].process_controlnet_conditioning(conditioning)
            processed_conditionings.append(model_output)
        return processed_conditionings

    def blockwise_forward(self, image, conditionings: list[torch.Tensor], controlnet_inputs: list[ControlNetInput], progress_id, num_inference_steps, block_id, **kwargs):
        res = 0
        for controlnet_input, conditioning in zip(controlnet_inputs, conditionings):
            progress = (num_inference_steps - 1 - progress_id) / max(num_inference_steps - 1, 1)
            if progress > controlnet_input.start + (1e-4) or progress < controlnet_input.end - (1e-4):
                continue
            model_output = self.models[controlnet_input.controlnet_id].blockwise_forward(image, conditioning, block_id)
            res = res + model_output * controlnet_input.scale
        return res
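
# The units below each prepare one group of model inputs (shape, noise, latents, prompt
# embeddings, control signals, ...) before the denoising loop starts.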
class QwenImageUnit_ShapeChecker(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("height", "width"),
            output_params=("height", "width"),
        )

    def process(self, pipe: QwenImagePipeline, height, width):
        height, width = pipe.check_resize_height_width(height, width)
        return {"height": height, "width": width}


class QwenImageUnit_NoiseInitializer(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("height", "width", "seed", "rand_device", "layer_num"),
            output_params=("noise",),
        )

    def process(self, pipe: QwenImagePipeline, height, width, seed, rand_device, layer_num):
        if layer_num is None:
            noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
        else:
            noise = pipe.generate_noise((layer_num + 1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
        return {"noise": noise}
class QwenImageUnit_InputImageEmbedder(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("input_image", "noise", "tiled", "tile_size", "tile_stride"),
            output_params=("latents", "input_latents"),
            onload_model_names=("vae",)
        )

    def process(self, pipe: QwenImagePipeline, input_image, noise, tiled, tile_size, tile_stride):
        if input_image is None:
            return {"latents": noise, "input_latents": None}
        pipe.load_models_to_device(['vae'])
        if isinstance(input_image, list):
            input_latents = []
            for image in input_image:
                image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype)
                input_latents.append(pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride))
            input_latents = torch.concat(input_latents, dim=0)
        else:
            image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype)
            input_latents = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        if pipe.scheduler.training:
            return {"latents": noise, "input_latents": input_latents}
        else:
            latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
            return {"latents": latents, "input_latents": input_latents}
class QwenImageUnit_LayerInputImageEmbedder(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("layer_input_image", "tiled", "tile_size", "tile_stride"),
            output_params=("layer_input_latents",),
            onload_model_names=("vae",)
        )

    def process(self, pipe: QwenImagePipeline, layer_input_image, tiled, tile_size, tile_stride):
        if layer_input_image is None:
            return {}
        pipe.load_models_to_device(['vae'])
        image = pipe.preprocess_image(layer_input_image).to(device=pipe.device, dtype=pipe.torch_dtype)
        latents = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        return {"layer_input_latents": latents}


class QwenImageUnit_Inpaint(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("inpaint_mask", "height", "width", "inpaint_blur_size", "inpaint_blur_sigma"),
            output_params=("inpaint_mask",),
        )

    def process(self, pipe: QwenImagePipeline, inpaint_mask, height, width, inpaint_blur_size, inpaint_blur_sigma):
        if inpaint_mask is None:
            return {}
        inpaint_mask = pipe.preprocess_image(inpaint_mask.convert("RGB").resize((width // 8, height // 8)), min_value=0, max_value=1)
        inpaint_mask = inpaint_mask.mean(dim=1, keepdim=True)
        if inpaint_blur_size is not None and inpaint_blur_sigma is not None:
            from torchvision.transforms import GaussianBlur
            blur = GaussianBlur(kernel_size=inpaint_blur_size * 2 + 1, sigma=inpaint_blur_sigma)
            inpaint_mask = blur(inpaint_mask)
        return {"inpaint_mask": inpaint_mask}
class QwenImageUnit_PromptEmbedder(PipelineUnit):
    def __init__(self):
        super().__init__(
            seperate_cfg=True,
            input_params_posi={"prompt": "prompt"},
            input_params_nega={"prompt": "negative_prompt"},
            input_params=("edit_image",),
            output_params=("prompt_emb", "prompt_emb_mask"),
            onload_model_names=("text_encoder",)
        )

    def extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
        bool_mask = mask.bool()
        valid_lengths = bool_mask.sum(dim=1)
        selected = hidden_states[bool_mask]
        split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
        return split_result

    def calculate_dimensions(self, target_area, ratio):
        width = math.sqrt(target_area * ratio)
        height = width / ratio
        width = round(width / 32) * 32
        height = round(height / 32) * 32
        return width, height

    def resize_image(self, image, target_area=384*384):
        width, height = self.calculate_dimensions(target_area, image.size[0] / image.size[1])
        return image.resize((width, height))

    def encode_prompt(self, pipe: QwenImagePipeline, prompt):
        template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
        drop_idx = 34
        txt = [template.format(e) for e in prompt]
        model_inputs = pipe.tokenizer(txt, max_length=4096+drop_idx, padding=True, truncation=True, return_tensors="pt").to(pipe.device)
        if model_inputs.input_ids.shape[1] >= 1024:
            print(f"Warning!!! QwenImage model was trained on prompts up to 512 tokens. Current prompt requires {model_inputs['input_ids'].shape[1] - drop_idx} tokens, which may lead to unpredictable behavior.")
        hidden_states = pipe.text_encoder(input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, output_hidden_states=True)[-1]
        split_hidden_states = self.extract_masked_hidden(hidden_states, model_inputs.attention_mask)
        split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
        return split_hidden_states

    def encode_prompt_edit(self, pipe: QwenImagePipeline, prompt, edit_image):
        template = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
        drop_idx = 64
        txt = [template.format(e) for e in prompt]
        model_inputs = pipe.processor(text=txt, images=edit_image, padding=True, return_tensors="pt").to(pipe.device)
        hidden_states = pipe.text_encoder(input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, pixel_values=model_inputs.pixel_values, image_grid_thw=model_inputs.image_grid_thw, output_hidden_states=True)[-1]
        split_hidden_states = self.extract_masked_hidden(hidden_states, model_inputs.attention_mask)
        split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
        return split_hidden_states

    def encode_prompt_edit_multi(self, pipe: QwenImagePipeline, prompt, edit_image):
        template = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
        drop_idx = 64
        img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>"
        base_img_prompt = "".join([img_prompt_template.format(i + 1) for i in range(len(edit_image))])
        txt = [template.format(base_img_prompt + e) for e in prompt]
        edit_image = [self.resize_image(image) for image in edit_image]
        model_inputs = pipe.processor(text=txt, images=edit_image, padding=True, return_tensors="pt").to(pipe.device)
        hidden_states = pipe.text_encoder(input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, pixel_values=model_inputs.pixel_values, image_grid_thw=model_inputs.image_grid_thw, output_hidden_states=True)[-1]
        split_hidden_states = self.extract_masked_hidden(hidden_states, model_inputs.attention_mask)
        split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
        return split_hidden_states

    def process(self, pipe: QwenImagePipeline, prompt, edit_image=None) -> dict:
        pipe.load_models_to_device(self.onload_model_names)
        if pipe.text_encoder is not None:
            prompt = [prompt]
            if edit_image is None:
                split_hidden_states = self.encode_prompt(pipe, prompt)
            elif isinstance(edit_image, Image.Image):
                split_hidden_states = self.encode_prompt_edit(pipe, prompt, edit_image)
            else:
                split_hidden_states = self.encode_prompt_edit_multi(pipe, prompt, edit_image)
            attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
            max_seq_len = max([e.size(0) for e in split_hidden_states])
            prompt_embeds = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states])
            encoder_attention_mask = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list])
            prompt_embeds = prompt_embeds.to(dtype=pipe.torch_dtype, device=pipe.device)
            return {"prompt_emb": prompt_embeds, "prompt_emb_mask": encoder_attention_mask}
        else:
            return {}
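
# EliGen entity control: each entity prompt is encoded separately and paired with a
# binary spatial mask so the DiT can route entity text to its image region.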
class QwenImageUnit_EntityControl(PipelineUnit):
    def __init__(self):
        super().__init__(
            take_over=True,
            input_params=("eligen_entity_prompts", "width", "height", "eligen_enable_on_negative", "cfg_scale"),
            output_params=("entity_prompt_emb", "entity_masks", "entity_prompt_emb_mask"),
            onload_model_names=("text_encoder",)
        )

    def extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
        bool_mask = mask.bool()
        valid_lengths = bool_mask.sum(dim=1)
        selected = hidden_states[bool_mask]
        split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
        return split_result

    def get_prompt_emb(self, pipe: QwenImagePipeline, prompt) -> dict:
        if pipe.text_encoder is not None:
            prompt = [prompt]
            template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
            drop_idx = 34
            txt = [template.format(e) for e in prompt]
            txt_tokens = pipe.tokenizer(txt, max_length=1024+drop_idx, padding=True, truncation=True, return_tensors="pt").to(pipe.device)
            hidden_states = pipe.text_encoder(input_ids=txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask, output_hidden_states=True)[-1]
            split_hidden_states = self.extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
            split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
            attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
            max_seq_len = max([e.size(0) for e in split_hidden_states])
            prompt_embeds = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states])
            encoder_attention_mask = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list])
            prompt_embeds = prompt_embeds.to(dtype=pipe.torch_dtype, device=pipe.device)
            return {"prompt_emb": prompt_embeds, "prompt_emb_mask": encoder_attention_mask}
        else:
            return {}

    def preprocess_masks(self, pipe, masks, height, width, dim):
        out_masks = []
        for mask in masks:
            mask = pipe.preprocess_image(mask.resize((width, height), resample=Image.NEAREST)).mean(dim=1, keepdim=True) > 0
            mask = mask.repeat(1, dim, 1, 1).to(device=pipe.device, dtype=pipe.torch_dtype)
            out_masks.append(mask)
        return out_masks

    def prepare_entity_inputs(self, pipe, entity_prompts, entity_masks, width, height):
        entity_masks = self.preprocess_masks(pipe, entity_masks, height//8, width//8, 1)
        entity_masks = torch.cat(entity_masks, dim=0).unsqueeze(0)  # b, n_mask, c, h, w
        prompt_embs, prompt_emb_masks = [], []
        for entity_prompt in entity_prompts:
            prompt_emb_dict = self.get_prompt_emb(pipe, entity_prompt)
            prompt_embs.append(prompt_emb_dict['prompt_emb'])
            prompt_emb_masks.append(prompt_emb_dict['prompt_emb_mask'])
        return prompt_embs, prompt_emb_masks, entity_masks

    def prepare_eligen(self, pipe, prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, enable_eligen_on_negative, cfg_scale):
        entity_prompt_emb_posi, entity_prompt_emb_posi_mask, entity_masks_posi = self.prepare_entity_inputs(pipe, eligen_entity_prompts, eligen_entity_masks, width, height)
        if enable_eligen_on_negative and cfg_scale != 1.0:
            entity_prompt_emb_nega = [prompt_emb_nega['prompt_emb']] * len(entity_prompt_emb_posi)
            entity_prompt_emb_nega_mask = [prompt_emb_nega['prompt_emb_mask']] * len(entity_prompt_emb_posi)
            entity_masks_nega = entity_masks_posi
        else:
            entity_prompt_emb_nega, entity_prompt_emb_nega_mask, entity_masks_nega = None, None, None
        eligen_kwargs_posi = {"entity_prompt_emb": entity_prompt_emb_posi, "entity_masks": entity_masks_posi, "entity_prompt_emb_mask": entity_prompt_emb_posi_mask}
        eligen_kwargs_nega = {"entity_prompt_emb": entity_prompt_emb_nega, "entity_masks": entity_masks_nega, "entity_prompt_emb_mask": entity_prompt_emb_nega_mask}
        return eligen_kwargs_posi, eligen_kwargs_nega

    def process(self, pipe: QwenImagePipeline, inputs_shared, inputs_posi, inputs_nega):
        eligen_entity_prompts, eligen_entity_masks = inputs_shared.get("eligen_entity_prompts", None), inputs_shared.get("eligen_entity_masks", None)
        if eligen_entity_prompts is None or eligen_entity_masks is None or len(eligen_entity_prompts) == 0 or len(eligen_entity_masks) == 0:
            return inputs_shared, inputs_posi, inputs_nega
        pipe.load_models_to_device(self.onload_model_names)
        eligen_enable_on_negative = inputs_shared.get("eligen_enable_on_negative", False)
        eligen_kwargs_posi, eligen_kwargs_nega = self.prepare_eligen(
            pipe, inputs_nega,
            eligen_entity_prompts, eligen_entity_masks, inputs_shared["width"], inputs_shared["height"],
            eligen_enable_on_negative, inputs_shared["cfg_scale"],
        )
        inputs_posi.update(eligen_kwargs_posi)
        if inputs_shared.get("cfg_scale", 1.0) != 1.0:
            inputs_nega.update(eligen_kwargs_nega)
        return inputs_shared, inputs_posi, inputs_nega
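
# VAE-encodes each ControlNet condition image; an optional inpaint mask zeroes out the
# masked pixels and is appended (inverted and downsampled) to the latents as an extra channel.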
class QwenImageUnit_BlockwiseControlNet(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("blockwise_controlnet_inputs", "tiled", "tile_size", "tile_stride"),
            output_params=("blockwise_controlnet_conditioning",),
            onload_model_names=("vae",)
        )

    def apply_controlnet_mask_on_latents(self, pipe, latents, mask):
        mask = (pipe.preprocess_image(mask) + 1) / 2
        mask = mask.mean(dim=1, keepdim=True)
        mask = 1 - torch.nn.functional.interpolate(mask, size=latents.shape[-2:])
        latents = torch.concat([latents, mask], dim=1)
        return latents

    def apply_controlnet_mask_on_image(self, pipe, image, mask):
        mask = mask.resize(image.size)
        mask = pipe.preprocess_image(mask).mean(dim=[0, 1]).cpu()
        image = np.array(image)
        image[mask > 0] = 0
        image = Image.fromarray(image)
        return image

    def process(self, pipe: QwenImagePipeline, blockwise_controlnet_inputs: list[ControlNetInput], tiled, tile_size, tile_stride):
        if blockwise_controlnet_inputs is None:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        conditionings = []
        for controlnet_input in blockwise_controlnet_inputs:
            image = controlnet_input.image
            if controlnet_input.inpaint_mask is not None:
                image = self.apply_controlnet_mask_on_image(pipe, image, controlnet_input.inpaint_mask)
            image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype)
            image = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
            if controlnet_input.inpaint_mask is not None:
                image = self.apply_controlnet_mask_on_latents(pipe, image, controlnet_input.inpaint_mask)
            conditionings.append(image)
        return {"blockwise_controlnet_conditioning": conditionings}
class QwenImageUnit_EditImageEmbedder(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("edit_image", "tiled", "tile_size", "tile_stride", "edit_image_auto_resize"),
            output_params=("edit_latents", "edit_image"),
            onload_model_names=("vae",)
        )

    def calculate_dimensions(self, target_area, ratio):
        width = math.sqrt(target_area * ratio)
        height = width / ratio
        width = round(width / 32) * 32
        height = round(height / 32) * 32
        return width, height

    def edit_image_auto_resize(self, edit_image):
        calculated_width, calculated_height = self.calculate_dimensions(1024 * 1024, edit_image.size[0] / edit_image.size[1])
        return edit_image.resize((calculated_width, calculated_height))

    def process(self, pipe: QwenImagePipeline, edit_image, tiled, tile_size, tile_stride, edit_image_auto_resize=False):
        if edit_image is None:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        if isinstance(edit_image, Image.Image):
            resized_edit_image = self.edit_image_auto_resize(edit_image) if edit_image_auto_resize else edit_image
            edit_image = pipe.preprocess_image(resized_edit_image).to(device=pipe.device, dtype=pipe.torch_dtype)
            edit_latents = pipe.vae.encode(edit_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        else:
            resized_edit_image, edit_latents = [], []
            for image in edit_image:
                if edit_image_auto_resize:
                    image = self.edit_image_auto_resize(image)
                resized_edit_image.append(image)
                image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype)
                latents = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
                edit_latents.append(latents)
        return {"edit_latents": edit_latents, "edit_image": resized_edit_image}
class QwenImageUnit_Image2LoRAEncode(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("image2lora_images",),
            output_params=("image2lora_x", "image2lora_residual", "image2lora_residual_highres"),
            onload_model_names=("siglip2_image_encoder", "dinov3_image_encoder", "text_encoder"),
        )
        from ..core.data.operators import ImageCropAndResize
        self.processor_lowres = ImageCropAndResize(height=28*8, width=28*8)
        self.processor_highres = ImageCropAndResize(height=1024, width=1024)

    def extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
        bool_mask = mask.bool()
        valid_lengths = bool_mask.sum(dim=1)
        selected = hidden_states[bool_mask]
        split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
        return split_result

    def encode_prompt_edit(self, pipe: QwenImagePipeline, prompt, edit_image):
        prompt = [prompt]
        template = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
        drop_idx = 64
        txt = [template.format(e) for e in prompt]
        model_inputs = pipe.processor(text=txt, images=edit_image, padding=True, return_tensors="pt").to(pipe.device)
        hidden_states = pipe.text_encoder(input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, pixel_values=model_inputs.pixel_values, image_grid_thw=model_inputs.image_grid_thw, output_hidden_states=True)[-1]
        split_hidden_states = self.extract_masked_hidden(hidden_states, model_inputs.attention_mask)
        split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
        max_seq_len = max([e.size(0) for e in split_hidden_states])
        prompt_embeds = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states])
        prompt_embeds = prompt_embeds.to(dtype=pipe.torch_dtype, device=pipe.device)
        return prompt_embeds.view(1, -1)

    def encode_images_using_siglip2(self, pipe: QwenImagePipeline, images: list[Image.Image]):
        pipe.load_models_to_device(["siglip2_image_encoder"])
        embs = []
        for image in images:
            image = self.processor_highres(image)
            embs.append(pipe.siglip2_image_encoder(image).to(pipe.torch_dtype))
        embs = torch.stack(embs)
        return embs

    def encode_images_using_dinov3(self, pipe: QwenImagePipeline, images: list[Image.Image]):
        pipe.load_models_to_device(["dinov3_image_encoder"])
        embs = []
        for image in images:
            image = self.processor_highres(image)
            embs.append(pipe.dinov3_image_encoder(image).to(pipe.torch_dtype))
        embs = torch.stack(embs)
        return embs

    def encode_images_using_qwenvl(self, pipe: QwenImagePipeline, images: list[Image.Image], highres=False):
        pipe.load_models_to_device(["text_encoder"])
        embs = []
        for image in images:
            image = self.processor_highres(image) if highres else self.processor_lowres(image)
            embs.append(self.encode_prompt_edit(pipe, prompt="", edit_image=image))
        embs = torch.stack(embs)
        return embs

    def encode_images(self, pipe: QwenImagePipeline, images: list[Image.Image]):
        if images is None:
            return {}
        if not isinstance(images, list):
            images = [images]
        embs_siglip2 = self.encode_images_using_siglip2(pipe, images)
        embs_dinov3 = self.encode_images_using_dinov3(pipe, images)
        x = torch.concat([embs_siglip2, embs_dinov3], dim=-1)
        residual = None
        residual_highres = None
        if pipe.image2lora_coarse is not None:
            residual = self.encode_images_using_qwenvl(pipe, images, highres=False)
        if pipe.image2lora_fine is not None:
            residual_highres = self.encode_images_using_qwenvl(pipe, images, highres=True)
        return x, residual, residual_highres

    def process(self, pipe: QwenImagePipeline, image2lora_images):
        if image2lora_images is None:
            return {}
        x, residual, residual_highres = self.encode_images(pipe, image2lora_images)
        return {"image2lora_x": x, "image2lora_residual": residual, "image2lora_residual_highres": residual_highres}
class QwenImageUnit_Image2LoRADecode(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("image2lora_x", "image2lora_residual", "image2lora_residual_highres"),
            output_params=("lora",),
            onload_model_names=("image2lora_coarse", "image2lora_fine", "image2lora_style"),
        )

    def process(self, pipe: QwenImagePipeline, image2lora_x, image2lora_residual, image2lora_residual_highres):
        if image2lora_x is None:
            return {}
        loras = []
        if pipe.image2lora_style is not None:
            pipe.load_models_to_device(["image2lora_style"])
            for x in image2lora_x:
                loras.append(pipe.image2lora_style(x=x, residual=None))
        if pipe.image2lora_coarse is not None:
            pipe.load_models_to_device(["image2lora_coarse"])
            for x, residual in zip(image2lora_x, image2lora_residual):
                loras.append(pipe.image2lora_coarse(x=x, residual=residual))
        if pipe.image2lora_fine is not None:
            pipe.load_models_to_device(["image2lora_fine"])
            for x, residual in zip(image2lora_x, image2lora_residual_highres):
                loras.append(pipe.image2lora_fine(x=x, residual=residual))
        lora = merge_lora(loras, alpha=1 / len(image2lora_x))
        return {"lora": lora}
class QwenImageUnit_ContextImageEmbedder(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("context_image", "height", "width", "tiled", "tile_size", "tile_stride"),
            output_params=("context_latents",),
            onload_model_names=("vae",)
        )

    def process(self, pipe: QwenImagePipeline, context_image, height, width, tiled, tile_size, tile_stride):
        if context_image is None:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        context_image = pipe.preprocess_image(context_image.resize((width, height))).to(device=pipe.device, dtype=pipe.torch_dtype)
        context_latents = pipe.vae.encode(context_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        return {"context_latents": context_latents}
def model_fn_qwen_image(
    dit: QwenImageDiT = None,
    blockwise_controlnet: QwenImageBlockwiseMultiControlNet = None,
    latents=None,
    timestep=None,
    prompt_emb=None,
    prompt_emb_mask=None,
    height=None,
    width=None,
    blockwise_controlnet_conditioning=None,
    blockwise_controlnet_inputs=None,
    progress_id=0,
    num_inference_steps=1,
    entity_prompt_emb=None,
    entity_prompt_emb_mask=None,
    entity_masks=None,
    edit_latents=None,
    layer_input_latents=None,
    layer_num=None,
    context_latents=None,
    enable_fp8_attention=False,
    use_gradient_checkpointing=False,
    use_gradient_checkpointing_offload=False,
    edit_rope_interpolation=False,
    zero_cond_t=False,
    **kwargs
):
    if layer_num is None:
        layer_num = 1
        img_shapes = [(1, latents.shape[2]//2, latents.shape[3]//2)]
    else:
        layer_num = layer_num + 1
        img_shapes = [(1, latents.shape[2]//2, latents.shape[3]//2)] * layer_num
    txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()
    timestep = timestep / 1000
    image = rearrange(latents, "(B N) C (H P) (W Q) -> B (N H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2, N=layer_num)
    image_seq_len = image.shape[1]
    if context_latents is not None:
        img_shapes += [(context_latents.shape[0], context_latents.shape[2]//2, context_latents.shape[3]//2)]
        context_image = rearrange(context_latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=context_latents.shape[2]//2, W=context_latents.shape[3]//2, P=2, Q=2)
        image = torch.cat([image, context_image], dim=1)
    if edit_latents is not None:
        edit_latents_list = edit_latents if isinstance(edit_latents, list) else [edit_latents]
        img_shapes += [(e.shape[0], e.shape[2]//2, e.shape[3]//2) for e in edit_latents_list]
        edit_image = [rearrange(e, "B C (H P) (W Q) -> B (H W) (C P Q)", H=e.shape[2]//2, W=e.shape[3]//2, P=2, Q=2) for e in edit_latents_list]
        image = torch.cat([image] + edit_image, dim=1)
    if layer_input_latents is not None:
        layer_num = layer_num + 1
        img_shapes += [(layer_input_latents.shape[0], layer_input_latents.shape[2]//2, layer_input_latents.shape[3]//2)]
        layer_input_latents = rearrange(layer_input_latents, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
        image = torch.cat([image, layer_input_latents], dim=1)
    image = dit.img_in(image)

    if zero_cond_t:
        timestep = torch.cat([timestep, timestep * 0], dim=0)
        modulate_index = torch.tensor(
            [[0] * prod(sample[0]) + [1] * sum([prod(s) for s in sample[1:]]) for sample in [img_shapes]],
            device=timestep.device,
            dtype=torch.int,
        )
    else:
        modulate_index = None
    conditioning = dit.time_text_embed(
        timestep,
        image.dtype,
        addition_t_cond=None if not dit.time_text_embed.use_additional_t_cond else torch.tensor([0]).to(device=image.device, dtype=torch.long)
    )

    if entity_prompt_emb is not None:
        text, image_rotary_emb, attention_mask = dit.process_entity_masks(
            latents, prompt_emb, prompt_emb_mask, entity_prompt_emb, entity_prompt_emb_mask,
            entity_masks, height, width, image, img_shapes,
        )
    else:
        text = dit.txt_in(dit.txt_norm(prompt_emb))
        if edit_rope_interpolation:
            image_rotary_emb = dit.pos_embed.forward_sampling(img_shapes, txt_seq_lens, device=latents.device)
        else:
            image_rotary_emb = dit.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
        attention_mask = None

    if blockwise_controlnet_conditioning is not None:
        blockwise_controlnet_conditioning = blockwise_controlnet.preprocess(
            blockwise_controlnet_inputs, blockwise_controlnet_conditioning)

    for block_id, block in enumerate(dit.transformer_blocks):
        text, image = gradient_checkpoint_forward(
            block,
            use_gradient_checkpointing,
            use_gradient_checkpointing_offload,
            image=image,
            text=text,
            temb=conditioning,
            image_rotary_emb=image_rotary_emb,
            attention_mask=attention_mask,
            enable_fp8_attention=enable_fp8_attention,
            modulate_index=modulate_index,
        )
        if blockwise_controlnet_conditioning is not None:
            image_slice = image[:, :image_seq_len].clone()
            controlnet_output = blockwise_controlnet.blockwise_forward(
                image=image_slice, conditionings=blockwise_controlnet_conditioning,
                controlnet_inputs=blockwise_controlnet_inputs, block_id=block_id,
                progress_id=progress_id, num_inference_steps=num_inference_steps,
            )
            image[:, :image_seq_len] = image_slice + controlnet_output

    if zero_cond_t:
        conditioning = conditioning.chunk(2, dim=0)[0]
    image = dit.norm_out(image, conditioning)
    image = dit.proj_out(image)
    image = image[:, :image_seq_len]
    latents = rearrange(image, "B (N H W) (C P Q) -> (B N) C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2, B=1)
    return latents
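

# A minimal usage sketch (not part of the original module). The ModelConfig model_id /
# origin_file_pattern values below are placeholders and must be replaced with the
# checkpoints actually published for Qwen-Image; only the call pattern is illustrated.
#
#   pipe = QwenImagePipeline.from_pretrained(
#       torch_dtype=torch.bfloat16,
#       model_configs=[
#           ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/"),   # DiT (placeholder pattern)
#           ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/"),  # text encoder (placeholder pattern)
#           ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/"),           # VAE (placeholder pattern)
#       ],
#   )
#   image = pipe(prompt="a cup of coffee on a wooden table", seed=0, num_inference_steps=30)
#   image.save("qwen_image_example.png")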