sdxs / src /pipeline_sdxs_no_pooling.py
recoilme's picture
2511
f08d8ce
from diffusers import DiffusionPipeline
import torch
from diffusers.utils import BaseOutput
from dataclasses import dataclass
from typing import List, Union, Optional
from PIL import Image
import numpy as np
from tqdm import tqdm
@dataclass
class SdxsPipelineOutput(BaseOutput):
images: Union[List[Image.Image], np.ndarray]
class SdxsPipeline(DiffusionPipeline):
def __init__(self, vae, text_encoder, tokenizer, unet, scheduler, max_length: int = 150):
super().__init__()
self.register_modules(
vae=vae, text_encoder=text_encoder, tokenizer=tokenizer,
unet=unet, scheduler=scheduler
)
self.vae_scale_factor = 16
self.max_length = max_length
def encode_prompt(self, prompt=None, negative_prompt=None, device=None, dtype=None):
device = device or self.device
dtype = dtype or next(self.unet.parameters()).dtype
# Преобразуем в списки
if isinstance(prompt, str):
prompt = [prompt]
if isinstance(negative_prompt, str):
negative_prompt = [negative_prompt]
# Если промпты не заданы, используем пустые эмбеддинги
if prompt is None and negative_prompt is None:
hidden_dim = 1024 # Размерность эмбеддинга Qwen3-0.6B
seq_len = self.max_length
batch_size = 1
# ИЗМЕНЕНИЕ 1: Для пустых эмбеддингов возвращаем также маску (единицы)
empty_embeds = torch.zeros((batch_size, seq_len, hidden_dim), dtype=dtype, device=device)
empty_mask = torch.ones((batch_size, seq_len), dtype=torch.int64, device=device)
return empty_embeds, empty_mask
# Токенизация с фиксированным max_length=150 и padding="max_length"
def encode_texts(texts, max_length=self.max_length):
with torch.no_grad():
toks = self.tokenizer(
texts,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=max_length
).to(device)
outs = self.text_encoder(**toks, output_hidden_states=True, return_dict=True)
# ИЗМЕНЕНИЕ 2: Используем outs.last_hidden_state (нормализованный)
hidden = outs.hidden_states[-1] #hidden = outs.last_hidden_state
# ИЗМЕНЕНИЕ 3: Получаем 2D маску
mask_for_unet = toks["attention_mask"]
# ИЗМЕНЕНИЕ 4: Удален zero-padding эмбеддингов, так как UNet использует mask_for_unet
return hidden, mask_for_unet # ИЗМЕНЕНИЕ 5: Возвращаем и эмбеддинг, и маску
# Кодируем позитивные и негативные промпты
pos_result = encode_texts(prompt) if prompt is not None else (None, None)
neg_result = encode_texts(negative_prompt) if negative_prompt is not None else (None, None)
pos_embeddings, pos_mask = pos_result
neg_embeddings, neg_mask = neg_result
# Выравниваем размеры batch_size
batch_size = max(
pos_embeddings.shape[0] if pos_embeddings is not None else 0,
neg_embeddings.shape[0] if neg_embeddings is not None else 0
)
# Повторяем эмбеддинги и маски по batch_size
if pos_embeddings is not None and pos_embeddings.shape[0] < batch_size:
pos_embeddings = pos_embeddings.repeat(batch_size, 1, 1)
pos_mask = pos_mask.repeat(batch_size, 1) # ИЗМЕНЕНИЕ 6: Повторяем маску
if neg_embeddings is not None and neg_embeddings.shape[0] < batch_size:
neg_embeddings = neg_embeddings.repeat(batch_size, 1, 1)
neg_mask = neg_mask.repeat(batch_size, 1) # ИЗМЕНЕНИЕ 7: Повторяем маску
# Конкатенируем для guidance (эмбеддинги и маски)
if pos_embeddings is not None and neg_embeddings is not None:
text_embeddings = torch.cat([neg_embeddings, pos_embeddings], dim=0)
attention_mask = torch.cat([neg_mask, pos_mask], dim=0) # ИЗМЕНЕНИЕ 8: Конкатенируем маску
elif pos_embeddings is not None:
text_embeddings = pos_embeddings
attention_mask = pos_mask
else:
text_embeddings = neg_embeddings
attention_mask = neg_mask
# ИЗМЕНЕНИЕ 9: Возвращаем кортеж
return (
text_embeddings.to(device=device, dtype=dtype),
attention_mask.to(device=device, dtype=torch.int64)
)
@torch.no_grad()
def generate_latents(
self,
text_embeddings,
attention_mask, # ИЗМЕНЕНИЕ 10: Принимаем маску
height: int = 1280,
width: int = 1024,
num_inference_steps: int = 40,
guidance_scale: float = 4.0,
latent_channels: int = 16,
batch_size: int = 1,
generator=None,
):
device = self.device
dtype = next(self.unet.parameters()).dtype
self.scheduler.set_timesteps(num_inference_steps, device=device)
# Разделяем эмбеддинги и маски на условные и безусловные
if guidance_scale > 1:
neg_embeds, pos_embeds = text_embeddings.chunk(2)
neg_mask, pos_mask = attention_mask.chunk(2) # ИЗМЕНЕНИЕ 11: Разделяем маски
# Повторяем, если batch_size больше
if batch_size > pos_embeds.shape[0]:
pos_embeds = pos_embeds.repeat(batch_size, 1, 1)
neg_embeds = neg_embeds.repeat(batch_size, 1, 1)
pos_mask = pos_mask.repeat(batch_size, 1) # ИЗМЕНЕНИЕ 12: Повторяем маски
neg_mask = neg_mask.repeat(batch_size, 1) # ИЗМЕНЕНИЕ 12: Повторяем маски
text_embeddings = torch.cat([neg_embeds, pos_embeds], dim=0)
unet_attention_mask = torch.cat([neg_mask, pos_mask], dim=0) # ИЗМЕНЕНИЕ 13: Конкатенируем маски для UNet
else:
text_embeddings = text_embeddings.repeat(batch_size, 1, 1)
unet_attention_mask = attention_mask.repeat(batch_size, 1) # ИЗМЕНЕНИЕ 14: Повторяем маску
# Инициализация латентов
latent_shape = (
batch_size,
latent_channels,
height // self.vae_scale_factor,
width // self.vae_scale_factor
)
latents = torch.randn(latent_shape, device=device, dtype=dtype, generator=generator)
# Процесс диффузии
for t in tqdm(self.scheduler.timesteps, desc="Генерация"):
latent_input = torch.cat([latents, latents], dim=0) if guidance_scale > 1 else latents
noise_pred = self.unet(
latent_input,
t,
encoder_hidden_states=text_embeddings,
encoder_attention_mask=unet_attention_mask # ИЗМЕНЕНИЕ 15: Передаем маску в UNet
).sample
if guidance_scale > 1:
noise_uncond, noise_text = noise_pred.chunk(2)
noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)
latents = self.scheduler.step(noise_pred, t, latents).prev_sample
return latents
def decode_latents(self, latents, output_type="pil"):
"""Декодирование латентов в изображения."""
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
with torch.no_grad():
images = self.vae.decode(latents).sample
images = (images / 2 + 0.5).clamp(0, 1)
if output_type == "pil":
images = images.cpu().permute(0, 2, 3, 1).float().numpy()
images = (images * 255).round().astype("uint8")
return [Image.fromarray(image) for image in images]
return images.cpu().permute(0, 2, 3, 1).float().numpy()
@torch.no_grad()
def __call__(
self,
prompt: Optional[Union[str, List[str]]] = None,
height: int = 1280,
width: int = 1024,
num_inference_steps: int = 40,
guidance_scale: float = 4.0,
latent_channels: int = 16,
output_type: str = "pil",
return_dict: bool = True,
batch_size: int = 1,
seed: Optional[int] = None,
negative_prompt: Optional[Union[str, List[str]]] = None,
text_embeddings: Optional[torch.FloatTensor] = None,
):
device = self.device
generator = torch.Generator(device=device).manual_seed(seed) if seed is not None else None
if text_embeddings is None:
if prompt is None and negative_prompt is None:
raise ValueError("Необходимо указать prompt, negative_prompt или text_embeddings")
# ИЗМЕНЕНИЕ 16: Получаем маску вместе с эмбеддингами
text_embeddings, attention_mask = self.encode_prompt(
prompt, negative_prompt, device=device, dtype=next(self.unet.parameters()).dtype
)
# text_embeddings уже имеет структуру [B_uncond + B_cond, seq_len, hid], dtype и device совместимы
latents = self.generate_latents(
text_embeddings=text_embeddings,
attention_mask=attention_mask, # ИЗМЕНЕНИЕ 17: Передаем маску в generate_latents
height=height,
width=width,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
latent_channels=latent_channels,
batch_size=batch_size,
generator=generator
)
images = self.decode_latents(latents, output_type=output_type)
if not return_dict:
return images
return SdxsPipelineOutput(images=images)