sdxs / src /pipeline_sdxs-Copy1.py
recoilme's picture
768
0e9e9bc
from diffusers import DiffusionPipeline
import torch
from diffusers.utils import BaseOutput
from dataclasses import dataclass
from typing import List, Union, Optional, Tuple
from PIL import Image
import numpy as np
from tqdm import tqdm
@dataclass
class SdxsPipelineOutput(BaseOutput):
images: Union[List[Image.Image], np.ndarray]
class SdxsPipeline(DiffusionPipeline):
def __init__(self, vae, text_encoder, tokenizer, unet, scheduler, max_length: int = 192):
super().__init__()
self.register_modules(
vae=vae, text_encoder=text_encoder, tokenizer=tokenizer,
unet=unet, scheduler=scheduler
)
self.vae_scale_factor = 16
self.max_length = max_length
def encode_prompt(self, prompt=None, negative_prompt=None, device=None, dtype=None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
device = device or self.device
dtype = dtype or next(self.unet.parameters()).dtype
# Преобразуем в списки
if isinstance(prompt, str):
prompt = [prompt]
if isinstance(negative_prompt, str):
negative_prompt = [negative_prompt]
# Если промпты не заданы, используем пустые эмбеддинги
if prompt is None and negative_prompt is None:
hidden_dim = 1024 # Размерность эмбеддинга
seq_len = self.max_length
batch_size = 1
# ИЗМЕНЕНО: Возвращаем три элемента: embeds, mask, pooled
empty_embeds = torch.zeros((batch_size, seq_len, hidden_dim), dtype=dtype, device=device)
empty_mask = torch.ones((batch_size, seq_len), dtype=torch.int64, device=device)
empty_pooled = torch.zeros((batch_size, hidden_dim), dtype=dtype, device=device)
return empty_embeds, empty_mask, empty_pooled
# Токенизация с фиксированным max_length и padding="max_length"
def encode_texts(texts, max_length=self.max_length):
with torch.no_grad():
if isinstance(texts, str):
texts = [texts]
for i, prompt_item in enumerate(texts):
messages = [
{"role": "user", "content": prompt_item},
]
prompt_item = self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=True,
)
texts[i] = prompt_item
toks = self.tokenizer(
texts,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=max_length
).to(device)
outs = self.text_encoder(**toks, output_hidden_states=True, return_dict=True)
# Токен-эмбеддинги (для Cross-Attention)
hidden = outs.hidden_states[-2] # Используем last hidden state -2???
# Маска внимания (для Cross-Attention)
attention_mask = toks["attention_mask"]
# Пулинг-эмбеддинг (для Class/Time Conditioning). Берем эмбеддинг последнего токена без padding.
sequence_lengths = attention_mask.sum(dim=1) - 1
batch_size = hidden.shape[0]
pooled = hidden[torch.arange(batch_size, device=hidden.device), sequence_lengths]
# --- НОВАЯ ЛОГИКА: ОБЪЕДИНЕНИЕ ДЛЯ КРОСС-ВНИМАНИЯ ---
# 1. Расширяем пулинг-вектор до последовательности [B, 1, 1024]
pooled_expanded = pooled.unsqueeze(1)
# 2. Объединяем последовательность токенов и пулинг-вектор
# !!! ИЗМЕНЕНИЕ ЗДЕСЬ !!!: Пулинг идет ПЕРВЫМ
# Теперь: [B, 1 + L, 1024]. Пулинг стал токеном в НАЧАЛЕ.
new_encoder_hidden_states = torch.cat([pooled_expanded, hidden], dim=1)
# 3. Обновляем маску внимания для нового токена
# Маска внимания: [B, 1 + L]. Добавляем 1 в НАЧАЛО.
# torch.ones((batch_size, 1), device=device) создает маску [B, 1] со значениями 1.
new_attention_mask = torch.cat([torch.ones((batch_size, 1), device=device), attention_mask], dim=1)
return new_encoder_hidden_states, new_attention_mask, pooled
# Кодируем позитивные и негативные промпты
# ИСПРАВЛЕНИЕ: Теперь возвращаем (None, None, None), чтобы избежать UnboundLocalError
pos_result = encode_texts(prompt) if prompt is not None else (None, None, None)
neg_result = encode_texts(negative_prompt) if negative_prompt is not None else (None, None, None)
pos_embeddings, pos_mask, pos_pooled = pos_result
neg_embeddings, neg_mask, neg_pooled = neg_result
# Выравниваем размеры batch_size
batch_size = max(
pos_embeddings.shape[0] if pos_embeddings is not None else 0,
neg_embeddings.shape[0] if neg_embeddings is not None else 0
)
# Повторяем эмбеддинги, маски и пулинг по batch_size
if pos_embeddings is not None and pos_embeddings.shape[0] < batch_size:
pos_embeddings = pos_embeddings.repeat(batch_size, 1, 1)
pos_mask = pos_mask.repeat(batch_size, 1)
pos_pooled = pos_pooled.repeat(batch_size, 1)
# ИСПРАВЛЕНИЕ: Проверяем, существует ли neg_embeddings, прежде чем обращаться к его shape[0]
if neg_embeddings is not None and neg_embeddings.shape[0] < batch_size:
neg_embeddings = neg_embeddings.repeat(batch_size, 1, 1)
neg_mask = neg_mask.repeat(batch_size, 1)
neg_pooled = neg_pooled.repeat(batch_size, 1)
# Конкатенируем для guidance (эмбеддинги и маски)
# Убеждаемся, что все три компонента существуют перед конкатенацией
if pos_embeddings is not None and neg_embeddings is not None:
text_embeddings = torch.cat([neg_embeddings, pos_embeddings], dim=0)
attention_mask = torch.cat([neg_mask, pos_mask], dim=0)
pooled_embeddings = torch.cat([neg_pooled, pos_pooled], dim=0)
elif pos_embeddings is not None:
text_embeddings = pos_embeddings
attention_mask = pos_mask
pooled_embeddings = pos_pooled
else: # Только neg_embeddings
text_embeddings = neg_embeddings
attention_mask = neg_mask
pooled_embeddings = neg_pooled
# Возвращаем кортеж
return (
text_embeddings.to(device=device, dtype=dtype),
attention_mask.to(device=device, dtype=torch.int64),
pooled_embeddings.to(device=device, dtype=dtype)
)
@torch.no_grad()
def generate_latents(
self,
text_embeddings,
attention_mask,
pooled_embeddings,
height: int = 1280,
width: int = 1024,
num_inference_steps: int = 40,
guidance_scale: float = 4.0,
latent_channels: int = 16,
batch_size: int = 1,
generator=None,
):
device = self.device
dtype = next(self.unet.parameters()).dtype
self.scheduler.set_timesteps(num_inference_steps, device=device)
# Разделяем эмбеддинги и маски на условные и безусловные
if guidance_scale > 1:
neg_embeds, pos_embeds = text_embeddings.chunk(2)
neg_mask, pos_mask = attention_mask.chunk(2)
neg_pooled, pos_pooled = pooled_embeddings.chunk(2)
# Повторяем, если batch_size больше
if batch_size > pos_embeds.shape[0]:
pos_embeds = pos_embeds.repeat(batch_size, 1, 1)
neg_embeds = neg_embeds.repeat(batch_size, 1, 1)
pos_mask = pos_mask.repeat(batch_size, 1)
neg_mask = neg_mask.repeat(batch_size, 1)
pos_pooled = pos_pooled.repeat(batch_size, 1)
neg_pooled = neg_pooled.repeat(batch_size, 1)
text_embeddings = torch.cat([neg_embeds, pos_embeds], dim=0)
unet_attention_mask = torch.cat([neg_mask, pos_mask], dim=0)
unet_pooled_embeddings = torch.cat([neg_pooled, pos_pooled], dim=0)
else:
text_embeddings = text_embeddings.repeat(batch_size, 1, 1)
unet_attention_mask = attention_mask.repeat(batch_size, 1)
unet_pooled_embeddings = pooled_embeddings.repeat(batch_size, 1)
# Инициализация латентов
latent_shape = (
batch_size,
latent_channels,
height // self.vae_scale_factor,
width // self.vae_scale_factor
)
latents = torch.randn(latent_shape, device=device, dtype=dtype, generator=generator)
# Процесс диффузии
for t in tqdm(self.scheduler.timesteps, desc="Генерация"):
latent_input = torch.cat([latents, latents], dim=0) if guidance_scale > 1 else latents
noise_pred = self.unet(
latent_input,
t,
encoder_hidden_states=text_embeddings,
encoder_attention_mask=unet_attention_mask,
#added_cond_kwargs={'text_embeds': unet_pooled_embeddings}
).sample
if guidance_scale > 1:
noise_uncond, noise_text = noise_pred.chunk(2)
noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)
latents = self.scheduler.step(noise_pred, t, latents).prev_sample
return latents
def decode_latents(self, latents, output_type="pil"):
"""Декодирование латентов в изображения."""
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
with torch.no_grad():
images = self.vae.decode(latents).sample
images = (images / 2 + 0.5).clamp(0, 1)
if output_type == "pil":
images = images.cpu().permute(0, 2, 3, 1).float().numpy()
images = (images * 255).round().astype("uint8")
return [Image.fromarray(image) for image in images]
return images.cpu().permute(0, 2, 3, 1).float().numpy()
@torch.no_grad()
def __call__(
self,
prompt: Optional[Union[str, List[str]]] = None,
height: int = 1280,
width: int = 1024,
num_inference_steps: int = 40,
guidance_scale: float = 4.0,
latent_channels: int = 16,
output_type: str = "pil",
return_dict: bool = True,
batch_size: int = 1,
seed: Optional[int] = None,
negative_prompt: Optional[Union[str, List[str]]] = None,
text_embeddings: Optional[torch.FloatTensor] = None,
):
device = self.device
generator = torch.Generator(device=device).manual_seed(seed) if seed is not None else None
if text_embeddings is None:
if prompt is None and negative_prompt is None:
raise ValueError("Необходимо указать prompt, negative_prompt или text_embeddings")
text_embeddings, attention_mask, pooled_embeddings = self.encode_prompt(
prompt, negative_prompt, device=device, dtype=next(self.unet.parameters()).dtype
)
else:
# Требуется, чтобы внешний text_embeddings содержал объединенные cond/uncond,
# но мы не можем получить attention_mask и pooled_embeddings.
# Для простоты лучше требовать prompt/negative_prompt.
raise NotImplementedError("Передача text_embeddings напрямую пока не поддерживает передачу маски и пулинга. Используйте prompt/negative_prompt.")
latents = self.generate_latents(
text_embeddings=text_embeddings,
attention_mask=attention_mask,
pooled_embeddings=pooled_embeddings,
height=height,
width=width,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
latent_channels=latent_channels,
batch_size=batch_size,
generator=generator
)
images = self.decode_latents(latents, output_type=output_type)
if not return_dict:
return images
return SdxsPipelineOutput(images=images)