from diffusers import DiffusionPipeline import torch from diffusers.utils import BaseOutput from dataclasses import dataclass from typing import List, Union, Optional from PIL import Image import numpy as np from tqdm import tqdm @dataclass class SdxsPipelineOutput(BaseOutput): images: Union[List[Image.Image], np.ndarray] class SdxsPipeline(DiffusionPipeline): def __init__(self, vae, text_encoder, tokenizer, unet, scheduler, max_length: int = 150): super().__init__() self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler ) self.vae_scale_factor = 16 self.max_length = max_length def encode_prompt(self, prompt=None, negative_prompt=None, device=None, dtype=None): device = device or self.device dtype = dtype or next(self.unet.parameters()).dtype # Преобразуем в списки if isinstance(prompt, str): prompt = [prompt] if isinstance(negative_prompt, str): negative_prompt = [negative_prompt] # Если промпты не заданы, используем пустые эмбеддинги if prompt is None and negative_prompt is None: hidden_dim = 1024 # Размерность эмбеддинга Qwen3-0.6B seq_len = self.max_length batch_size = 1 # ИЗМЕНЕНИЕ 1: Для пустых эмбеддингов возвращаем также маску (единицы) empty_embeds = torch.zeros((batch_size, seq_len, hidden_dim), dtype=dtype, device=device) empty_mask = torch.ones((batch_size, seq_len), dtype=torch.int64, device=device) return empty_embeds, empty_mask # Токенизация с фиксированным max_length=150 и padding="max_length" def encode_texts(texts, max_length=self.max_length): with torch.no_grad(): toks = self.tokenizer( texts, return_tensors="pt", padding="max_length", truncation=True, max_length=max_length ).to(device) outs = self.text_encoder(**toks, output_hidden_states=True, return_dict=True) # ИЗМЕНЕНИЕ 2: Используем outs.last_hidden_state (нормализованный) hidden = outs.hidden_states[-1] #hidden = outs.last_hidden_state # ИЗМЕНЕНИЕ 3: Получаем 2D маску mask_for_unet = toks["attention_mask"] # ИЗМЕНЕНИЕ 4: Удален zero-padding эмбеддингов, так как UNet использует mask_for_unet return hidden, mask_for_unet # ИЗМЕНЕНИЕ 5: Возвращаем и эмбеддинг, и маску # Кодируем позитивные и негативные промпты pos_result = encode_texts(prompt) if prompt is not None else (None, None) neg_result = encode_texts(negative_prompt) if negative_prompt is not None else (None, None) pos_embeddings, pos_mask = pos_result neg_embeddings, neg_mask = neg_result # Выравниваем размеры batch_size batch_size = max( pos_embeddings.shape[0] if pos_embeddings is not None else 0, neg_embeddings.shape[0] if neg_embeddings is not None else 0 ) # Повторяем эмбеддинги и маски по batch_size if pos_embeddings is not None and pos_embeddings.shape[0] < batch_size: pos_embeddings = pos_embeddings.repeat(batch_size, 1, 1) pos_mask = pos_mask.repeat(batch_size, 1) # ИЗМЕНЕНИЕ 6: Повторяем маску if neg_embeddings is not None and neg_embeddings.shape[0] < batch_size: neg_embeddings = neg_embeddings.repeat(batch_size, 1, 1) neg_mask = neg_mask.repeat(batch_size, 1) # ИЗМЕНЕНИЕ 7: Повторяем маску # Конкатенируем для guidance (эмбеддинги и маски) if pos_embeddings is not None and neg_embeddings is not None: text_embeddings = torch.cat([neg_embeddings, pos_embeddings], dim=0) attention_mask = torch.cat([neg_mask, pos_mask], dim=0) # ИЗМЕНЕНИЕ 8: Конкатенируем маску elif pos_embeddings is not None: text_embeddings = pos_embeddings attention_mask = pos_mask else: text_embeddings = neg_embeddings attention_mask = neg_mask # ИЗМЕНЕНИЕ 9: Возвращаем кортеж return ( text_embeddings.to(device=device, dtype=dtype), attention_mask.to(device=device, dtype=torch.int64) ) @torch.no_grad() def generate_latents( self, text_embeddings, attention_mask, # ИЗМЕНЕНИЕ 10: Принимаем маску height: int = 1280, width: int = 1024, num_inference_steps: int = 40, guidance_scale: float = 4.0, latent_channels: int = 16, batch_size: int = 1, generator=None, ): device = self.device dtype = next(self.unet.parameters()).dtype self.scheduler.set_timesteps(num_inference_steps, device=device) # Разделяем эмбеддинги и маски на условные и безусловные if guidance_scale > 1: neg_embeds, pos_embeds = text_embeddings.chunk(2) neg_mask, pos_mask = attention_mask.chunk(2) # ИЗМЕНЕНИЕ 11: Разделяем маски # Повторяем, если batch_size больше if batch_size > pos_embeds.shape[0]: pos_embeds = pos_embeds.repeat(batch_size, 1, 1) neg_embeds = neg_embeds.repeat(batch_size, 1, 1) pos_mask = pos_mask.repeat(batch_size, 1) # ИЗМЕНЕНИЕ 12: Повторяем маски neg_mask = neg_mask.repeat(batch_size, 1) # ИЗМЕНЕНИЕ 12: Повторяем маски text_embeddings = torch.cat([neg_embeds, pos_embeds], dim=0) unet_attention_mask = torch.cat([neg_mask, pos_mask], dim=0) # ИЗМЕНЕНИЕ 13: Конкатенируем маски для UNet else: text_embeddings = text_embeddings.repeat(batch_size, 1, 1) unet_attention_mask = attention_mask.repeat(batch_size, 1) # ИЗМЕНЕНИЕ 14: Повторяем маску # Инициализация латентов latent_shape = ( batch_size, latent_channels, height // self.vae_scale_factor, width // self.vae_scale_factor ) latents = torch.randn(latent_shape, device=device, dtype=dtype, generator=generator) # Процесс диффузии for t in tqdm(self.scheduler.timesteps, desc="Генерация"): latent_input = torch.cat([latents, latents], dim=0) if guidance_scale > 1 else latents noise_pred = self.unet( latent_input, t, encoder_hidden_states=text_embeddings, encoder_attention_mask=unet_attention_mask # ИЗМЕНЕНИЕ 15: Передаем маску в UNet ).sample if guidance_scale > 1: noise_uncond, noise_text = noise_pred.chunk(2) noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond) latents = self.scheduler.step(noise_pred, t, latents).prev_sample return latents def decode_latents(self, latents, output_type="pil"): """Декодирование латентов в изображения.""" latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor with torch.no_grad(): images = self.vae.decode(latents).sample images = (images / 2 + 0.5).clamp(0, 1) if output_type == "pil": images = images.cpu().permute(0, 2, 3, 1).float().numpy() images = (images * 255).round().astype("uint8") return [Image.fromarray(image) for image in images] return images.cpu().permute(0, 2, 3, 1).float().numpy() @torch.no_grad() def __call__( self, prompt: Optional[Union[str, List[str]]] = None, height: int = 1280, width: int = 1024, num_inference_steps: int = 40, guidance_scale: float = 4.0, latent_channels: int = 16, output_type: str = "pil", return_dict: bool = True, batch_size: int = 1, seed: Optional[int] = None, negative_prompt: Optional[Union[str, List[str]]] = None, text_embeddings: Optional[torch.FloatTensor] = None, ): device = self.device generator = torch.Generator(device=device).manual_seed(seed) if seed is not None else None if text_embeddings is None: if prompt is None and negative_prompt is None: raise ValueError("Необходимо указать prompt, negative_prompt или text_embeddings") # ИЗМЕНЕНИЕ 16: Получаем маску вместе с эмбеддингами text_embeddings, attention_mask = self.encode_prompt( prompt, negative_prompt, device=device, dtype=next(self.unet.parameters()).dtype ) # text_embeddings уже имеет структуру [B_uncond + B_cond, seq_len, hid], dtype и device совместимы latents = self.generate_latents( text_embeddings=text_embeddings, attention_mask=attention_mask, # ИЗМЕНЕНИЕ 17: Передаем маску в generate_latents height=height, width=width, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, latent_channels=latent_channels, batch_size=batch_size, generator=generator ) images = self.decode_latents(latents, output_type=output_type) if not return_dict: return images return SdxsPipelineOutput(images=images)