from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from diffusers import DiffusionPipeline
from diffusers.utils import BaseOutput
from PIL import Image
from tqdm import tqdm


@dataclass
class SdxsPipelineOutput(BaseOutput):
    images: Union[List[Image.Image], np.ndarray]


class SdxsPipeline(DiffusionPipeline):
    def __init__(self, vae, text_encoder, tokenizer, unet, scheduler, max_length: int = 192):
        super().__init__()
        self.register_modules(
            vae=vae, text_encoder=text_encoder, tokenizer=tokenizer,
            unet=unet, scheduler=scheduler
        )
        # Spatial downsampling factor of the VAE: an HxW image maps to an
        # (H/16)x(W/16) latent.
        self.vae_scale_factor = 16
        self.max_length = max_length

    def encode_prompt(self, prompt=None, negative_prompt=None, device=None, dtype=None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode prompt and/or negative prompt into (embeddings, attention mask, pooled embeddings)."""
        device = device or self.device
        dtype = dtype or next(self.unet.parameters()).dtype

        if isinstance(prompt, str):
            prompt = [prompt]
        if isinstance(negative_prompt, str):
            negative_prompt = [negative_prompt]

        # With no prompts at all, return zero (unconditional) embeddings.
        if prompt is None and negative_prompt is None:
            hidden_dim = getattr(self.text_encoder.config, "hidden_size", 1024)
            # +1 to match the pooled token that encode_texts prepends below.
            seq_len = self.max_length + 1
            batch_size = 1

            empty_embeds = torch.zeros((batch_size, seq_len, hidden_dim), dtype=dtype, device=device)
            empty_mask = torch.ones((batch_size, seq_len), dtype=torch.int64, device=device)
            empty_pooled = torch.zeros((batch_size, hidden_dim), dtype=dtype, device=device)
            return empty_embeds, empty_mask, empty_pooled

        def encode_texts(texts, max_length=self.max_length):
            with torch.no_grad():
                if isinstance(texts, str):
                    texts = [texts]

                # Wrap each prompt in the tokenizer's chat template (the text
                # encoder is a chat-style language model); build a new list so
                # the caller's list is not mutated.
                texts = [
                    self.tokenizer.apply_chat_template(
                        [{"role": "user", "content": prompt_item}],
                        tokenize=False,
                        add_generation_prompt=True,
                        enable_thinking=True,
                    )
                    for prompt_item in texts
                ]

                toks = self.tokenizer(
                    texts,
                    return_tensors="pt",
                    padding="max_length",
                    truncation=True,
                    max_length=max_length,
                ).to(device)
                outs = self.text_encoder(**toks, output_hidden_states=True, return_dict=True)

                # Use the penultimate hidden layer as the conditioning sequence.
                hidden = outs.hidden_states[-2]
                attention_mask = toks["attention_mask"]

                # Pool by taking the hidden state of the last non-padding token.
                sequence_lengths = attention_mask.sum(dim=1) - 1
                batch_size = hidden.shape[0]
                pooled = hidden[torch.arange(batch_size, device=hidden.device), sequence_lengths]

                # Prepend the pooled embedding as an extra token and extend the
                # attention mask to match; the new column must keep the mask's
                # integer dtype for torch.cat to succeed.
                pooled_expanded = pooled.unsqueeze(1)
                new_encoder_hidden_states = torch.cat([pooled_expanded, hidden], dim=1)
                new_attention_mask = torch.cat(
                    [torch.ones((batch_size, 1), device=device, dtype=attention_mask.dtype), attention_mask],
                    dim=1,
                )

                return new_encoder_hidden_states, new_attention_mask, pooled

        pos_result = encode_texts(prompt) if prompt is not None else (None, None, None)
        neg_result = encode_texts(negative_prompt) if negative_prompt is not None else (None, None, None)

        pos_embeddings, pos_mask, pos_pooled = pos_result
        neg_embeddings, neg_mask, neg_pooled = neg_result

        batch_size = max(
            pos_embeddings.shape[0] if pos_embeddings is not None else 0,
            neg_embeddings.shape[0] if neg_embeddings is not None else 0,
        )

        # Broadcast a single positive prompt across the batch if needed.
        if pos_embeddings is not None and pos_embeddings.shape[0] < batch_size:
            pos_embeddings = pos_embeddings.repeat(batch_size, 1, 1)
            pos_mask = pos_mask.repeat(batch_size, 1)
            pos_pooled = pos_pooled.repeat(batch_size, 1)

        # Likewise for a single negative prompt.
        if neg_embeddings is not None and neg_embeddings.shape[0] < batch_size:
            neg_embeddings = neg_embeddings.repeat(batch_size, 1, 1)
            neg_mask = neg_mask.repeat(batch_size, 1)
            neg_pooled = neg_pooled.repeat(batch_size, 1)

        # For classifier-free guidance, stack negatives first, positives second;
        # generate_latents relies on this ordering when it splits the batch.
        if pos_embeddings is not None and neg_embeddings is not None:
            text_embeddings = torch.cat([neg_embeddings, pos_embeddings], dim=0)
            attention_mask = torch.cat([neg_mask, pos_mask], dim=0)
            pooled_embeddings = torch.cat([neg_pooled, pos_pooled], dim=0)
        elif pos_embeddings is not None:
            text_embeddings = pos_embeddings
            attention_mask = pos_mask
            pooled_embeddings = pos_pooled
        else:
            text_embeddings = neg_embeddings
            attention_mask = neg_mask
            pooled_embeddings = neg_pooled

        return (
            text_embeddings.to(device=device, dtype=dtype),
            attention_mask.to(device=device, dtype=torch.int64),
            pooled_embeddings.to(device=device, dtype=dtype),
        )
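
    # Note: encode_prompt returns sequences of length max_length + 1, since the
    # pooled (last-token) embedding is prepended as an extra conditioning token.
    # With the default max_length=192 and a 1024-dim text encoder, encoding one
    # positive and one negative prompt yields a (2, 193, 1024) embedding tensor.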

    @torch.no_grad()
    def generate_latents(
        self,
        text_embeddings,
        attention_mask,
        pooled_embeddings,
        height: int = 1280,
        width: int = 1024,
        num_inference_steps: int = 40,
        guidance_scale: float = 4.0,
        latent_channels: int = 16,
        batch_size: int = 1,
        generator=None,
    ):
        """Run the denoising loop and return the final latents."""
        device = self.device
        dtype = next(self.unet.parameters()).dtype

        self.scheduler.set_timesteps(num_inference_steps, device=device)

        if guidance_scale > 1:
            # text_embeddings holds the [negative, positive] halves produced by
            # encode_prompt; split them so each can be tiled to the batch size.
            neg_embeds, pos_embeds = text_embeddings.chunk(2)
            neg_mask, pos_mask = attention_mask.chunk(2)
            neg_pooled, pos_pooled = pooled_embeddings.chunk(2)

            if batch_size > pos_embeds.shape[0]:
                pos_embeds = pos_embeds.repeat(batch_size, 1, 1)
                neg_embeds = neg_embeds.repeat(batch_size, 1, 1)
                pos_mask = pos_mask.repeat(batch_size, 1)
                neg_mask = neg_mask.repeat(batch_size, 1)
                pos_pooled = pos_pooled.repeat(batch_size, 1)
                neg_pooled = neg_pooled.repeat(batch_size, 1)

            text_embeddings = torch.cat([neg_embeds, pos_embeds], dim=0)
            unet_attention_mask = torch.cat([neg_mask, pos_mask], dim=0)
            unet_pooled_embeddings = torch.cat([neg_pooled, pos_pooled], dim=0)
        else:
            text_embeddings = text_embeddings.repeat(batch_size, 1, 1)
            unet_attention_mask = attention_mask.repeat(batch_size, 1)
            unet_pooled_embeddings = pooled_embeddings.repeat(batch_size, 1)

        latent_shape = (
            batch_size,
            latent_channels,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        latents = torch.randn(latent_shape, device=device, dtype=dtype, generator=generator)

        for t in tqdm(self.scheduler.timesteps, desc="Generating"):
            # With guidance, duplicate the latents so a single UNet pass covers
            # both the unconditional and the conditional branch.
            latent_input = torch.cat([latents, latents], dim=0) if guidance_scale > 1 else latents

            # The pooled embedding is already prepended to the hidden states in
            # encode_prompt, so unet_pooled_embeddings is not passed separately.
            noise_pred = self.unet(
                latent_input,
                t,
                encoder_hidden_states=text_embeddings,
                encoder_attention_mask=unet_attention_mask,
            ).sample

            if guidance_scale > 1:
                # Classifier-free guidance: uncond + scale * (cond - uncond).
                noise_uncond, noise_text = noise_pred.chunk(2)
                noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)

            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        return latents
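
    # A minimal sketch of driving the stages directly (hypothetical values):
    #
    #     embeds, mask, pooled = pipe.encode_prompt("a cat", negative_prompt="")
    #     latents = pipe.generate_latents(embeds, mask, pooled, guidance_scale=4.0)
    #     images = pipe.decode_latents(latents)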

    def decode_latents(self, latents, output_type="pil"):
        """Decode latents into images."""
        # Undo the VAE's latent normalization before decoding.
        latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
        with torch.no_grad():
            images = self.vae.decode(latents).sample
        # Map from [-1, 1] to [0, 1].
        images = (images / 2 + 0.5).clamp(0, 1)

        if output_type == "pil":
            images = images.cpu().permute(0, 2, 3, 1).float().numpy()
            images = (images * 255).round().astype("uint8")
            return [Image.fromarray(image) for image in images]
        return images.cpu().permute(0, 2, 3, 1).float().numpy()

    @torch.no_grad()
    def __call__(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        height: int = 1280,
        width: int = 1024,
        num_inference_steps: int = 40,
        guidance_scale: float = 4.0,
        latent_channels: int = 16,
        output_type: str = "pil",
        return_dict: bool = True,
        batch_size: int = 1,
        seed: Optional[int] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        text_embeddings: Optional[torch.FloatTensor] = None,
    ):
        device = self.device
        generator = torch.Generator(device=device).manual_seed(seed) if seed is not None else None

        if text_embeddings is None:
            if prompt is None and negative_prompt is None:
                raise ValueError("Either prompt, negative_prompt, or text_embeddings must be provided")

            # generate_latents expects stacked [negative, positive] embeddings
            # when classifier-free guidance is enabled, so fall back to an empty
            # negative prompt if none was given.
            if guidance_scale > 1 and prompt is not None and negative_prompt is None:
                negative_prompt = ""

            text_embeddings, attention_mask, pooled_embeddings = self.encode_prompt(
                prompt, negative_prompt, device=device, dtype=next(self.unet.parameters()).dtype
            )
        else:
            raise NotImplementedError(
                "Passing text_embeddings directly does not yet support supplying the "
                "attention mask and pooled embeddings. Use prompt/negative_prompt instead."
            )

        latents = self.generate_latents(
            text_embeddings=text_embeddings,
            attention_mask=attention_mask,
            pooled_embeddings=pooled_embeddings,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            latent_channels=latent_channels,
            batch_size=batch_size,
            generator=generator,
        )

        images = self.decode_latents(latents, output_type=output_type)
        if not return_dict:
            return images
        return SdxsPipelineOutput(images=images)
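

# A minimal usage sketch. The checkpoint path and module name are placeholders,
# assuming the pipeline components are saved in the usual diffusers layout:
#
#     import torch
#     from sdxs_pipeline import SdxsPipeline  # hypothetical module name
#
#     pipe = SdxsPipeline.from_pretrained(
#         "path/to/sdxs-checkpoint", torch_dtype=torch.float16
#     ).to("cuda")
#     result = pipe(
#         "a photo of a red fox in the snow",
#         negative_prompt="low quality, blurry",
#         num_inference_steps=40,
#         guidance_scale=4.0,
#         seed=42,
#     )
#     result.images[0].save("fox.png")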