| from typing import TYPE_CHECKING, Optional |
| from pydash import flatten |
|
|
| import torch |
| from transformers.models.clip.tokenization_clip import CLIPTokenizer |
| from einops import repeat |
|
|
| if TYPE_CHECKING: |
| from flux_pipeline import FluxPipeline |
|
|
|
|
def parse_prompt_attention(text):
    """
    Parse a prompt with attention markup into a list of [text, weight] pairs.

    Accepted tokens are:
      (abc)      - multiplies attention to abc by 1.1
      (abc:3.12) - multiplies attention to abc by 3.12
      [abc]      - divides attention to abc by 1.1
      \\( \\) \\[ \\] \\  - literal characters
      BREAK      - emitted as a ["BREAK", -1] marker entry
      anything else - plain text at weight 1.0

    >>> parse_prompt_attention('normal text')
    [['normal text', 1.0]]
    >>> parse_prompt_attention('an (important) word')
    [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
    >>> parse_prompt_attention('(unbalanced')
    [['unbalanced', 1.1]]
    >>> parse_prompt_attention('\\(literal\\]')
    [['(literal]', 1.0]]
    >>> parse_prompt_attention('(unnecessary)(parens)')
    [['unnecessaryparens', 1.1]]
    >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
    [['a ', 1.0],
     ['house', 1.5730000000000004],
     [' ', 1.1],
     ['on', 1.0],
     [' a ', 1.1],
     ['hill', 0.55],
     [', sun, ', 1.1],
     ['sky', 1.4641000000000006],
     ['.', 1.1]]
    """
    import re

    # One alternation that tokenizes escapes, brackets, ":<number>)" weight
    # suffixes, plain-text runs and stray colons.
    attention_re = re.compile(
        r"""
        \\\(|\\\)|\\\[|\\]|\\\\|\\|\(|\[|:([+-]?[.\d]+)\)|
        \)|]|[^\\()\[\]:]+|:
        """,
        re.X,
    )
    break_re = re.compile(r"\s*\bBREAK\b\s*", re.S)

    up_factor = 1.1        # implicit multiplier for ( ... )
    down_factor = 1 / 1.1  # implicit multiplier for [ ... ]

    pieces = []       # accumulated [text, weight] entries
    open_round = []   # stack of pieces-indices where '(' opened
    open_square = []  # stack of pieces-indices where '[' opened

    def scale_from(start_index, factor):
        # Apply `factor` to every entry emitted since the bracket opened.
        for idx in range(start_index, len(pieces)):
            pieces[idx][1] *= factor

    for match in attention_re.finditer(text):
        fragment = match.group(0)
        explicit_weight = match.group(1)

        if fragment.startswith("\\"):
            # Escaped character: emit it literally without the backslash.
            pieces.append([fragment[1:], 1.0])
        elif fragment == "(":
            open_round.append(len(pieces))
        elif fragment == "[":
            open_square.append(len(pieces))
        elif explicit_weight is not None and open_round:
            # ":<num>)" closes the innermost '(' with an explicit weight.
            scale_from(open_round.pop(), float(explicit_weight))
        elif fragment == ")" and open_round:
            scale_from(open_round.pop(), up_factor)
        elif fragment == "]" and open_square:
            scale_from(open_square.pop(), down_factor)
        else:
            # Plain text; split out BREAK keywords into sentinel entries.
            for part_index, part in enumerate(re.split(break_re, fragment)):
                if part_index > 0:
                    pieces.append(["BREAK", -1])
                pieces.append([part, 1.0])

    # Unbalanced opening brackets still apply their implicit multiplier.
    for start_index in open_round:
        scale_from(start_index, up_factor)
    for start_index in open_square:
        scale_from(start_index, down_factor)

    if not pieces:
        pieces = [["", 1.0]]

    # Collapse consecutive entries that carry an identical weight.
    merged = [pieces[0]]
    for entry in pieces[1:]:
        if entry[1] == merged[-1][1]:
            merged[-1][0] += entry[0]
        else:
            merged.append(entry)

    return merged
|
|
|
|
def get_prompts_tokens_with_weights(
    clip_tokenizer: CLIPTokenizer, prompt: str, debug: bool = False
):
    """
    Tokenize a weighted prompt into flat token-id and per-token weight lists.

    Works for both positive and negative prompts: the prompt is first split
    into (text, weight) fragments by ``parse_prompt_attention``, then each
    fragment is tokenized without special tokens and its weight is repeated
    once per produced token id.

    Args:
        clip_tokenizer (CLIPTokenizer): tokenizer used to encode fragments.
        prompt (str): prompt string, optionally with attention markup.
        debug (bool): when True, print each fragment's ids and round-trip
            decode for inspection.

    Returns:
        tuple[list, list]: (token ids, matching per-token weights).
    """
    fragments = parse_prompt_attention(prompt)
    all_token_ids = []
    all_token_weights = []
    model_limit = clip_tokenizer.model_max_length

    for fragment_text, fragment_weight in fragments:
        fragment_ids = clip_tokenizer(
            fragment_text, truncation=False, padding=False, add_special_tokens=False
        ).input_ids

        if debug:
            print(
                fragment_ids,
                "|FOR MODEL LEN{}|".format(model_limit),
                clip_tokenizer.decode(
                    fragment_ids,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=True,
                ),
            )

        all_token_ids.extend(fragment_ids)
        # Every token produced by this fragment inherits the fragment weight.
        all_token_weights.extend([fragment_weight] * len(fragment_ids))

    return all_token_ids, all_token_weights
|
|
|
|
def group_tokens_and_weights(
    token_ids: list,
    weights: list,
    pad_last_block=False,
    bos=49406,
    eos=49407,
    max_length=77,
    pad_tokens=True,
):
    """
    Split flat token/weight lists into fixed-size groups, optionally adding
    BOS/EOS and right-padding the final group.

    NOTE: ``token_ids`` and ``weights`` are consumed in place — every full
    chunk is removed from the front of the input lists.

    Args:
        token_ids (list): token ids from ``get_prompts_tokens_with_weights``.
        weights (list): matching per-token weights.
        pad_last_block (bool): when True, pad the final group with ``eos``
            (weight 1.0) up to the group size.
        bos (int | None): begin-of-sequence id; ``None`` (e.g. T5) skips it.
        eos (int): end-of-sequence id.
        max_length (int): context size; values < 77 reserve 2 slots for
            BOS/EOS, the CLIP default of 77 keeps the full width.
        pad_tokens (bool): when False, groups are emitted without any
            BOS/EOS decoration or padding.

    Returns:
        tuple[list, list]: (2-D token-id groups, 2-D weight groups).
    """
    max_len = max_length - 2 if max_length < 77 else max_length

    new_token_ids = []
    new_weights = []
    while len(token_ids) >= max_len:
        # Slice + del is O(chunk) instead of the O(n^2) pop(0)-per-token
        # loop, while leaving the input lists in the same consumed state.
        head_ids = token_ids[:max_len]
        head_weights = weights[:max_len]
        del token_ids[:max_len]
        del weights[:max_len]

        if pad_tokens:
            if bos is not None:
                head_ids = [bos] + head_ids + [eos]
                head_weights = [1.0] + head_weights + [1.0]
            else:
                head_ids = head_ids + [eos]
                head_weights = head_weights + [1.0]

        new_token_ids.append(head_ids)
        new_weights.append(head_weights)

    # Leftover tokens shorter than one full group.
    if len(token_ids) > 0:
        if pad_tokens:
            padding_len = max_len - len(token_ids) if pad_last_block else 0
            # Guard bos=None (e.g. T5 has no BOS token) the same way the
            # full-chunk branch above does; previously this branch would
            # prepend a literal None token id.
            prefix = [bos] if bos is not None else []
            new_token_ids.append(prefix + token_ids + [eos] * padding_len + [eos])
            new_weights.append(
                [1.0] * len(prefix) + weights + [1.0] * padding_len + [1.0]
            )
        else:
            new_token_ids.append(token_ids)
            new_weights.append(weights)
    return new_token_ids, new_weights
|
|
|
|
def standardize_tensor(
    input_tensor: torch.Tensor, target_mean: float, target_std: float
) -> torch.Tensor:
    """
    Rescale a tensor so it has the requested mean and standard deviation.

    The tensor is first normalized to zero mean / unit variance using its
    own statistics, then shifted and scaled to the targets.

    Parameters:
        input_tensor (torch.Tensor): tensor to rescale.
        target_mean (float): desired mean of the output.
        target_std (float): desired standard deviation of the output.

    Returns:
        torch.Tensor: the rescaled tensor.
    """
    current_mean = input_tensor.mean()
    current_std = input_tensor.std()
    zero_mean_unit_var = (input_tensor - current_mean) / current_std
    return zero_mean_unit_var * target_std + target_mean
|
|
|
|
def apply_weights(
    prompt_tokens: torch.Tensor,
    weight_tensor: torch.Tensor,
    token_embedding: torch.Tensor,
    eos_token_id: int,
    pad_last_block: bool = True,
) -> torch.FloatTensor:
    """
    Blend each token embedding toward/away from a pooled reference embedding
    according to its prompt weight, then restore the embedding tensor's
    original mean and standard deviation.

    Note: ``token_embedding`` is modified in place before being passed to
    ``standardize_tensor``.

    Args:
        prompt_tokens (torch.Tensor): token ids, used to locate the first
            EOS position per batch row when ``pad_last_block`` is True.
        weight_tensor (torch.Tensor): per-position weights; positions with
            weight 1.0 are left untouched.
        token_embedding (torch.Tensor): encoder hidden states to reweight.
        eos_token_id (int): id marking end-of-sequence in ``prompt_tokens``.
        pad_last_block (bool): pool at the first EOS position when True,
            otherwise at the last sequence position.

    Returns:
        torch.FloatTensor: reweighted embeddings, standardized back to the
        input tensor's original statistics.
    """
    # Capture statistics before any in-place edits so we can restore them.
    original_mean = token_embedding.mean()
    original_std = token_embedding.std()

    if pad_last_block:
        batch_index = torch.arange(
            token_embedding.shape[0], device=token_embedding.device
        )
        # argmax over the boolean mask yields the first EOS position per row.
        eos_positions = (
            (
                prompt_tokens.to(dtype=torch.int, device=token_embedding.device)
                == eos_token_id
            )
            .int()
            .argmax(dim=-1)
        )
        pooled_tensor = token_embedding[batch_index, eos_positions]
    else:
        pooled_tensor = token_embedding[:, -1]

    for position, weight in enumerate(weight_tensor):
        if weight != 1.0:
            # Interpolate between the pooled embedding and the token
            # embedding; weight > 1 pushes away from the pooled reference.
            token_embedding[:, position] = (
                pooled_tensor
                + (token_embedding[:, position] - pooled_tensor) * weight
            )

    return standardize_tensor(token_embedding, original_mean, original_std)
|
|
|
|
@torch.inference_mode()
def get_weighted_text_embeddings_flux(
    pipe: "FluxPipeline",
    prompt: str = "",
    num_images_per_prompt: int = 1,
    device: Optional[torch.device] = None,
    target_device: Optional[torch.device] = torch.device("cuda:0"),
    target_dtype: Optional[torch.dtype] = torch.bfloat16,
    debug: bool = False,
):
    """
    Build CLIP and T5 text embeddings for a Flux pipeline from a prompt that
    may contain attention-weight markup (see ``parse_prompt_attention``).

    The prompt is tokenized with weights for both encoders, chunked with
    ``group_tokens_and_weights``, decoded back to text and re-tokenized to
    each encoder's fixed context length, then the T5 hidden states are
    reweighted via ``apply_weights``.

    Args:
        pipe (FluxPipeline): pipeline exposing ``clip``/``t5`` wrappers
            (each with ``.tokenizer`` and ``.hf_module``) and a ``name``.
        prompt (str): prompt string, optionally with attention markup.
        num_images_per_prompt (int): batch size the embeddings are repeated to.
        device (torch.device): device for the encoder forward passes;
            defaults to ``pipe._execution_device``.
        target_device (torch.device): device of the returned tensors.
        target_dtype (torch.dtype): dtype of the returned tensors.
        debug (bool): print tokenization/shape details.

    Returns:
        tuple: ``(clip_embeds, t5_embeds, txt_ids)`` — the CLIP pooled
        output, the weighted T5 last hidden state, and a zero tensor of
        positional ids with shape ``(num_images_per_prompt, seq_len, 3)``.
    """
    device = device or pipe._execution_device

    # Special-token ids for both tokenizers (T5's bos_token_id may be None).
    eos = pipe.clip.tokenizer.eos_token_id
    eos_2 = pipe.t5.tokenizer.eos_token_id
    bos = pipe.clip.tokenizer.bos_token_id
    bos_2 = pipe.t5.tokenizer.bos_token_id

    clip = pipe.clip.hf_module
    t5 = pipe.t5.hf_module

    tokenizer_clip = pipe.clip.tokenizer
    tokenizer_t5 = pipe.t5.tokenizer

    # flux-dev uses a 512-token T5 context; other variants use 256.
    t5_length = 512 if pipe.name == "flux-dev" else 256
    clip_length = 77

    # Tokenize the weighted prompt for each encoder.
    prompt_tokens_clip, prompt_weights_clip = get_prompts_tokens_with_weights(
        tokenizer_clip, prompt, debug=debug
    )

    prompt_tokens_t5, prompt_weights_t5 = get_prompts_tokens_with_weights(
        tokenizer_t5, prompt, debug=debug
    )

    # Chunk into context-sized groups; CLIP gets BOS/EOS decoration,
    # T5 groups are left undecorated (pad_tokens=False).
    prompt_tokens_clip_grouped, prompt_weights_clip_grouped = group_tokens_and_weights(
        prompt_tokens_clip,
        prompt_weights_clip,
        pad_last_block=True,
        bos=bos,
        eos=eos,
        max_length=clip_length,
    )
    prompt_tokens_t5_grouped, prompt_weights_t5_grouped = group_tokens_and_weights(
        prompt_tokens_t5,
        prompt_weights_t5,
        pad_last_block=True,
        bos=bos_2,
        eos=eos_2,
        max_length=t5_length,
        pad_tokens=False,
    )
    prompt_tokens_t5 = flatten(prompt_tokens_t5_grouped)
    prompt_weights_t5 = flatten(prompt_weights_t5_grouped)
    prompt_tokens_clip = flatten(prompt_tokens_clip_grouped)
    prompt_weights_clip = flatten(prompt_weights_clip_grouped)

    # Round-trip decode -> re-tokenize so each encoder receives ids padded /
    # truncated to its exact context length with its own special tokens.
    prompt_tokens_clip = tokenizer_clip.decode(
        prompt_tokens_clip, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    prompt_tokens_clip = tokenizer_clip(
        prompt_tokens_clip,
        add_special_tokens=True,
        padding="max_length",
        truncation=True,
        max_length=clip_length,
        return_tensors="pt",
    ).input_ids.to(device)
    prompt_tokens_t5 = tokenizer_t5.decode(
        prompt_tokens_t5, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    prompt_tokens_t5 = tokenizer_t5(
        prompt_tokens_t5,
        add_special_tokens=True,
        padding="max_length",
        truncation=True,
        max_length=t5_length,
        return_tensors="pt",
    ).input_ids.to(device)

    # Pad the T5 weights with 1.0 up to the full context length.
    # NOTE(review): if the weighted prompt yields more than t5_length
    # weights, the torch.full size goes negative and this raises — seems
    # to rely on upstream truncation; confirm.
    prompt_weights_t5 = torch.cat(
        [
            torch.tensor(prompt_weights_t5, dtype=torch.float32),
            torch.full(
                (t5_length - torch.tensor(prompt_weights_t5).numel(),),
                1.0,
                dtype=torch.float32,
            ),
        ],
        dim=0,
    ).to(device)

    # CLIP contributes only its pooled output.
    clip_embeds = clip(
        prompt_tokens_clip, output_hidden_states=True, attention_mask=None
    )["pooler_output"]
    if clip_embeds.shape[0] == 1 and num_images_per_prompt > 1:
        clip_embeds = repeat(clip_embeds, "1 ... -> bs ...", bs=num_images_per_prompt)

    # NOTE(review): prompt_weights_t5 is already a 1-D tensor here, so
    # pydash.flatten over it looks like a pass-through element unwrap —
    # verify it is intentional rather than a leftover from the list form.
    weight_tensor_t5 = torch.tensor(
        flatten(prompt_weights_t5), dtype=torch.float32, device=device
    )
    t5_embeds = t5(prompt_tokens_t5, output_hidden_states=True, attention_mask=None)[
        "last_hidden_state"
    ]
    # Reweight the T5 hidden states in place per token weight.
    t5_embeds = apply_weights(prompt_tokens_t5, weight_tensor_t5, t5_embeds, eos_2)
    if debug:
        print(t5_embeds.shape)
    if t5_embeds.shape[0] == 1 and num_images_per_prompt > 1:
        t5_embeds = repeat(t5_embeds, "1 ... -> bs ...", bs=num_images_per_prompt)
    # Positional ids expected by the Flux transformer; zeros of shape
    # (batch, seq_len, 3).
    txt_ids = torch.zeros(
        num_images_per_prompt,
        t5_embeds.shape[1],
        3,
        device=target_device,
        dtype=target_dtype,
    )
    t5_embeds = t5_embeds.to(target_device, dtype=target_dtype)
    clip_embeds = clip_embeds.to(target_device, dtype=target_dtype)

    return (
        clip_embeds,
        t5_embeds,
        txt_ids,
    )
|
|