Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| # -*- coding: utf-8 -*- | |
| import torch | |
| import math | |
| from PIL import Image | |
| from typing import List, Optional | |
| from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLProcessor | |
| from diffusers import FlowMatchEulerDiscreteScheduler, AutoencoderKL, BitsAndBytesConfig | |
| from unipicv2.pipeline_stable_diffusion_3_kontext import StableDiffusion3KontextPipeline | |
| from unipicv2.transformer_sd3_kontext import SD3Transformer2DKontextModel | |
| from unipicv2.stable_diffusion_3_conditioner import StableDiffusion3Conditioner | |
| import spaces | |
class UniPicV2Inferencer:
    """Text-to-image, image-editing and image-understanding front end for UniPic-V2.

    A Qwen2.5-VL model (``self.lmm``) encodes the prompt, plus the source image
    when editing. Its last hidden states at a block of appended "meta query"
    positions are mapped by ``StableDiffusion3Conditioner`` to SD3-style prompt
    embeddings. Those embeddings drive a ``StableDiffusion3KontextPipeline``.
    All model weights are loaded lazily on the first call to a public method.
    """

    def __init__(
        self,
        model_path: str,
        qwen_vl_path: str,
        quant: str = "fp16",  # "int4" -> 4-bit NF4 transformer; any other value -> bfloat16 transformer
        image_size: int = 512,
        default_negative_prompt: str = "blurry, low quality, low resolution, distorted, deformed, broken content, missing parts, damaged details, artifacts, glitch, noise, pixelated, grainy, compression artifacts, bad composition, wrong proportion, incomplete editing, unfinished, unedited areas."
    ):
        """Store configuration. No model weights are loaded here.

        Args:
            model_path: UniPic-V2 checkpoint root. Its ``transformer``, ``vae``,
                ``conditioner`` and ``scheduler`` subfolders are loaded.
            qwen_vl_path: Path or hub id of the Qwen2.5-VL model and processor.
            quant: ``"int4"`` loads the SD3 transformer 4-bit via bitsandbytes.
                Any other value loads it in bfloat16.
            image_size: Default height/width for ``generate_image``. Also the
                target long side for images resized in ``edit_image``.
            default_negative_prompt: Negative prompt used by ``edit_image``
                when the caller passes none.
        """
        self.model_path = model_path
        self.qwen_vl_path = qwen_vl_path
        self.quant = quant
        self.image_size = image_size
        self.default_negative_prompt = default_negative_prompt
        self.device = torch.device("cuda")
        # Populated lazily by _init_pipeline() on first use.
        self.pipeline = None
        self.lmm = None
        self.processor = None
        self.conditioner = None

    def _ensure_pipeline(self) -> None:
        """Build the pipeline and its LMM/processor/conditioner on first use."""
        if self.pipeline is None:
            self.pipeline = self._init_pipeline()

    def _init_pipeline(self) -> StableDiffusion3KontextPipeline:
        """Load every model component and assemble the SD3 Kontext pipeline.

        Side effects: sets ``self.lmm``, ``self.processor`` and ``self.conditioner``.

        Returns:
            The assembled pipeline. The caller stores it in ``self.pipeline``.
        """
        print("Initializing UniPicV2 pipeline...")

        # ===== 1. Load SD3 Transformer (4-bit NF4 via bitsandbytes, or bf16) =====
        if self.quant == "int4":
            bnb4 = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,
            )
            transformer = SD3Transformer2DKontextModel.from_pretrained(
                self.model_path, subfolder="transformer",
                quantization_config=bnb4, device_map="auto", low_cpu_mem_usage=True
            )
        else:
            transformer = SD3Transformer2DKontextModel.from_pretrained(
                self.model_path, subfolder="transformer",
                torch_dtype=torch.bfloat16, device_map="auto", low_cpu_mem_usage=True
            )

        # ===== 2. Load VAE =====
        vae = AutoencoderKL.from_pretrained(
            self.model_path, subfolder="vae",
            torch_dtype=torch.bfloat16, device_map="auto", low_cpu_mem_usage=True
        ).to(self.device)

        # ===== 3. Load Qwen2.5-VL (LMM): prefer flash-attention 2, fall back to SDPA =====
        try:
            self.lmm = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                self.qwen_vl_path,
                torch_dtype=torch.bfloat16,
                attn_implementation="flash_attention_2",
                device_map="auto",
            ).to(self.device)
            print("**" * 20)
        except Exception as exc:
            # Best-effort: flash-attn may be missing or unsupported on this GPU.
            print(f"flash_attention_2 unavailable ({exc}); falling back to sdpa")
            self.lmm = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                self.qwen_vl_path,
                torch_dtype=torch.bfloat16,
                attn_implementation="sdpa",
                device_map="auto",
            ).to(self.device)

        # ===== 4. Load Processor =====
        self.processor = Qwen2_5_VLProcessor.from_pretrained(self.qwen_vl_path, use_fast=False)
        if hasattr(self.processor, "chat_template") and self.processor.chat_template:
            # Strip the default "You are a helpful assistant." system turn from the template.
            self.processor.chat_template = self.processor.chat_template.replace(
                "{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}",
                ""
            )

        # ===== 5. Load Conditioner (LMM hidden states -> SD3 prompt embeddings) =====
        self.conditioner = StableDiffusion3Conditioner.from_pretrained(
            self.model_path, subfolder="conditioner",
            torch_dtype=torch.bfloat16, low_cpu_mem_usage=True
        ).to(self.device)

        # ===== 6. Load Scheduler =====
        scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
            self.model_path, subfolder="scheduler"
        )

        # ===== 7. Create Pipeline (text encoders unused: embeddings come from the conditioner) =====
        pipeline = StableDiffusion3KontextPipeline(
            transformer=transformer,
            vae=vae,
            text_encoder=None,
            tokenizer=None,
            text_encoder_2=None,
            tokenizer_2=None,
            text_encoder_3=None,
            tokenizer_3=None,
            scheduler=scheduler
        )

        # Enable each memory-saving feature independently, so one failing feature
        # (e.g. a method missing on this pipeline class) does not skip the rest.
        for feature in ("enable_vae_slicing", "enable_vae_tiling", "enable_model_cpu_offload"):
            try:
                getattr(pipeline, feature)()
            except Exception as exc:
                print(f"Note: Could not {feature}: {exc}")

        print("Pipeline initialization complete!")
        return pipeline

    def _prepare_text_inputs(self, prompt: str, negative_prompt: Optional[str] = None):
        """Tokenize a [positive, negative] pair of text-only chat messages.

        Args:
            prompt: Positive prompt text.
            negative_prompt: Negative prompt text. ``None`` is treated as ``""``.

        Returns:
            The processor output (padded ``input_ids``/``attention_mask`` tensors)
            for a batch of two: row 0 is positive, row 1 is negative.
        """
        negative_prompt = negative_prompt if negative_prompt is not None else ""
        messages = [
            [{"role": "user", "content": [{"type": "text", "text": prompt}]}],
            [{"role": "user", "content": [{"type": "text", "text": negative_prompt}]}]
        ]
        texts = [
            self.processor.apply_chat_template(m, tokenize=False, add_generation_prompt=True)
            for m in messages
        ]
        inputs = self.processor(
            text=texts,
            images=None,
            padding=True,
            return_tensors="pt"
        )
        return inputs

    def _prepare_image_inputs(self, image: Image.Image, prompt: str, negative_prompt: Optional[str] = None):
        """Tokenize a [positive, negative] pair of image+text chat messages.

        Both rows reference the same ``image``. ``negative_prompt`` falls back to
        ``self.default_negative_prompt`` when empty or ``None``.

        ``min_pixels == max_pixels`` pins the processor to one exact pixel budget:
        the image area scaled by (28/32)^2. This presumably maps a 32-aligned image
        onto an exact 28-pixel vision-patch grid; TODO confirm against the
        Qwen2.5-VL processor.
        """
        negative_prompt = negative_prompt or self.default_negative_prompt
        messages = [
            [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}]}],
            [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": negative_prompt}]}]
        ]
        texts = [
            self.processor.apply_chat_template(m, tokenize=False, add_generation_prompt=True)
            for m in messages
        ]
        min_pixels = max_pixels = int(image.height * 28 / 32 * image.width * 28 / 32)
        inputs = self.processor(
            text=texts,
            images=[image] * 2,
            min_pixels=min_pixels,
            max_pixels=max_pixels,
            padding=True,
            return_tensors="pt"
        )
        return inputs

    @torch.no_grad()
    def _process_inputs(self, inputs: dict, num_queries: int) -> dict:
        """Run the LMM over tokenized inputs and turn meta-query states into SD3 embeddings.

        ``num_queries`` placeholder positions are appended to every sequence. Their
        input embeddings are overwritten with ``self.conditioner.meta_queries``.
        The LMM's last hidden states at those positions are then fed through the
        conditioner. Row 0 of the batch is treated as the positive prompt and
        row 1 as the negative prompt, matching the ``_prepare_*_inputs`` helpers.

        Runs under ``torch.no_grad()``: this is pure inference, and tracking
        gradients would store activations for the entire LMM forward pass.

        Returns:
            dict with ``prompt_embeds``, ``pooled_prompt_embeds``,
            ``negative_prompt_embeds`` and ``negative_pooled_prompt_embeds``.
        """
        inputs = {
            k: v.to(self.device) if isinstance(v, torch.Tensor) else v
            for k, v in inputs.items()
        }
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        batch_size = input_ids.size(0)

        # Append placeholder tokens (id 0) for the meta queries; they are always attended to.
        pad_ids = torch.zeros((batch_size, num_queries), dtype=input_ids.dtype, device=self.device)
        pad_mask = torch.ones((attention_mask.size(0), num_queries),
                              dtype=attention_mask.dtype, device=self.device)
        input_ids = torch.cat([input_ids, pad_ids], dim=1)
        attention_mask = torch.cat([attention_mask, pad_mask], dim=1)

        # Embed on whichever device holds the embedding weights.
        embed_layer = self.lmm.get_input_embeddings()
        input_ids = input_ids.to(embed_layer.weight.device)
        inputs_embeds = embed_layer(input_ids)

        # Overwrite the placeholder embeddings with the learned meta queries
        # (copied to the embeddings' device/dtype; the parameter itself is not mutated).
        meta_queries = self.conditioner.meta_queries.to(inputs_embeds.device, inputs_embeds.dtype)
        inputs_embeds[:, -num_queries:] = meta_queries[None].expand(batch_size, -1, -1)

        # Scatter vision-encoder outputs into the <|image_pad|> token positions.
        if "pixel_values" in inputs:
            image_embeds = self.lmm.visual(
                inputs["pixel_values"],
                grid_thw=inputs["image_grid_thw"]
            )
            image_token_id = self.processor.tokenizer.convert_tokens_to_ids('<|image_pad|>')
            mask_img = (input_ids == image_token_id)
            inputs_embeds[mask_img] = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)

        # Reset rope_deltas before the forward pass (presumably clears rotary
        # offsets cached by an earlier call such as generate(); TODO confirm).
        if hasattr(self.lmm.model, "rope_deltas"):
            self.lmm.model.rope_deltas = None
        outputs = self.lmm.model(
            inputs_embeds=inputs_embeds.to(self.device),
            attention_mask=attention_mask.to(self.device),
            image_grid_thw=inputs.get("image_grid_thw", None),
            use_cache=False
        )
        hidden_states = outputs.last_hidden_state[:, -num_queries:].to(self.device)

        prompt_embeds, pooled_prompt_embeds = self.conditioner(hidden_states)
        return {
            "prompt_embeds": prompt_embeds[:1],
            "pooled_prompt_embeds": pooled_prompt_embeds[:1],
            "negative_prompt_embeds": prompt_embeds[1:],
            "negative_pooled_prompt_embeds": pooled_prompt_embeds[1:]
        }

    def _resize_image(self, image: Image.Image, size: int) -> Image.Image:
        """Resize ``image`` so its longer side is ``size``, keeping the aspect ratio.

        ``size`` is floored to a multiple of 32. The shorter side is floored to a
        multiple of 32 and clamped to at least 32, so extreme aspect ratios never
        yield a zero-sized image. The image is scaled up or down as needed.
        """
        size = max(32, (size // 32) * 32)
        w, h = image.size
        if w >= h:
            new_w = size
            new_h = max(32, (int(h * (new_w / w)) // 32) * 32)
        else:
            new_h = size
            new_w = max(32, (int(w * (new_h / h)) // 32) * 32)
        return image.resize((new_w, new_h))

    def generate_image(
        self,
        prompt: str,
        negative_prompt: Optional[str] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 36,
        guidance_scale: float = 3.0,
        seed: int = 42
    ) -> List[Image.Image]:
        """Text-to-image generation.

        Args:
            prompt: Text prompt. It is prefixed with "Generate an image: ".
            negative_prompt: Optional negative prompt, given the same prefix.
                When empty or ``None`` an empty string is used;
                ``self.default_negative_prompt`` is not applied here.
            height, width: Output size. Each defaults to ``self.image_size``.
            num_inference_steps: Number of denoising steps.
            guidance_scale: Classifier-free guidance scale.
            seed: Seed for the pipeline's random generator.

        Returns:
            The pipeline's ``.images`` list.
        """
        self._ensure_pipeline()
        height = height or self.image_size
        width = width or self.image_size
        prompt = "Generate an image: " + prompt
        negative_prompt = "Generate an image: " + negative_prompt if negative_prompt else ""

        inputs = self._prepare_text_inputs(prompt, negative_prompt)
        embeds = self._process_inputs(inputs, self.conditioner.config.num_queries)

        generator = torch.Generator(device=self.device).manual_seed(seed)
        image = self.pipeline(
            prompt_embeds=embeds["prompt_embeds"].to(self.device),
            pooled_prompt_embeds=embeds["pooled_prompt_embeds"].to(self.device),
            negative_prompt_embeds=embeds["negative_prompt_embeds"].to(self.device),
            negative_pooled_prompt_embeds=embeds["negative_pooled_prompt_embeds"].to(self.device),
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=generator
        ).images
        return image

    def edit_image(
        self,
        image: Image.Image,
        prompt: str,
        negative_prompt: Optional[str] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 36,
        guidance_scale: float = 3.0,
        seed: int = 42
    ) -> List[Image.Image]:
        """Edit ``image`` according to ``prompt``.

        The image is converted to RGB and resized (see ``_resize_image``) to
        ``self.image_size`` on its longer side.

        Args:
            image: Source image. Any non-RGB mode is converted to RGB; alpha is dropped.
            prompt: Edit instruction.
            negative_prompt: Optional negative prompt. Defaults to
                ``self.default_negative_prompt``.
            height, width: Output size. If either is ``None``, both are taken
                from the resized source image.
            num_inference_steps: Number of denoising steps.
            guidance_scale: Classifier-free guidance scale.
            seed: Seed for the initial latent noise and the pipeline generator.

        Returns:
            The pipeline's ``.images`` list.
        """
        # The VAE path needs 3-channel input; grayscale/palette/CMYK/alpha modes are normalized here.
        if image.mode != "RGB":
            image = image.convert("RGB")
        self._ensure_pipeline()

        image = self._resize_image(image, self.image_size)
        if height is None or width is None:
            height, width = image.height, image.width

        inputs = self._prepare_image_inputs(image, prompt, negative_prompt)
        embeds = self._process_inputs(inputs, self.conditioner.config.num_queries)

        generator = torch.Generator(device=self.device).manual_seed(seed)
        # Draw the initial noise from the seeded generator so `seed` makes edits reproducible.
        latents = torch.randn(
            1, self.pipeline.transformer.config.in_channels,
            height // self.pipeline.vae_scale_factor,
            width // self.pipeline.vae_scale_factor,
            generator=generator,
            device=self.device,
            dtype=self.pipeline.transformer.dtype,
        )
        edited_image = self.pipeline(
            latents=latents,
            image=image,
            prompt_embeds=embeds["prompt_embeds"].to(self.device),
            pooled_prompt_embeds=embeds["pooled_prompt_embeds"].to(self.device),
            negative_prompt_embeds=embeds["negative_prompt_embeds"].to(self.device),
            negative_pooled_prompt_embeds=embeds["negative_pooled_prompt_embeds"].to(self.device),
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=generator
        ).images
        return edited_image

    def understand_image(
        self,
        image: Image.Image,
        prompt: str,
        max_new_tokens: int = 512
    ) -> str:
        """
        Understand the content of an image and answer questions about it.

        The LMM and processor are created by the same lazy initialization as the
        diffusion pipeline, so the first call loads every component.

        Args:
            image: Input image to understand
            prompt: Question or instruction about the image
            max_new_tokens: Maximum number of tokens to generate
        Returns:
            str: The model's response to the prompt
        """
        self._ensure_pipeline()

        # Single-turn Qwen-VL chat message containing the image and the question.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": prompt},
                ],
            },
        ]
        text = self.processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        # Same fixed pixel budget as _prepare_image_inputs.
        min_pixels = max_pixels = int(image.height * 28 / 32 * image.width * 28 / 32)
        inputs = self.processor(
            text=[text],
            images=[image],
            min_pixels=min_pixels,
            max_pixels=max_pixels,
            padding=True,
            return_tensors="pt"
        ).to(self.device)

        generated_ids = self.lmm.generate(
            **inputs,
            max_new_tokens=max_new_tokens
        )
        # Drop the prompt tokens so only the newly generated answer is decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )[0]
        return output_text