import os

import torch

from diffusers import DiffusionPipeline, FluxPipeline, StableDiffusion3Pipeline
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

from .model import DrUM as backbone
from .sampling import coreset_sampling


def stable_diffusion(large):
    """
    Stable Diffusion v1.x: openai/clip-vit-large-patch14 (CLIPTextModel), skip -1.
    """
    def inference(prompt, ref_prompt = None, weight = None, alpha = 0.3, skip = -1, batch_size = 64, **kwargs):
        return large(prompt, ref_prompt, pooling = False, weight = weight, alpha = alpha, skip = skip, batch_size = batch_size, **kwargs), None
    return inference


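# A hedged usage sketch of the factory pattern above (prompts are illustrative;
# `large` is a DrUM-wrapped CLIP ViT-L/14 text encoder as built in `peca` below):
#
#   encode = stable_diffusion(large)
#   cond, pool = encode("a cat", ref_prompt = ["a watercolor sketch"], alpha = 0.3)
#   # pool is None: SD v1.x conditions on token hidden states only.

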
def stable_diffusion_v2(huge):
    """
    Stable Diffusion v2.x: OpenCLIP ViT-H/14 (CLIPTextModel), skip -1.
    """
    def inference(prompt, ref_prompt = None, weight = None, alpha = 0.3, skip = -1, batch_size = 64, **kwargs):
        return huge(prompt, ref_prompt, pooling = False, weight = weight, alpha = alpha, skip = skip, batch_size = batch_size, **kwargs), None
    return inference


def stable_diffusion_xl(large, bigG):
    """
    Stable Diffusion XL:
    openai/clip-vit-large-patch14 (CLIPTextModel), skip -2, unnormalized
    laion/CLIP-ViT-bigG-14-laion2B-39B-b160k (CLIPTextModelWithProjection), skip -2, unnormalized, pooling + projection
    """
    def inference(prompt, ref_prompt = None, weight = None, alpha = 0.3, skip = -2, batch_size = 64, **kwargs):
        hidden_state = large(prompt, ref_prompt, pooling = False, weight = weight, alpha = alpha, skip = skip, batch_size = batch_size, normalize = False, **kwargs)
        if skip == -1:
            hidden_state2, pool_hidden_state = bigG(prompt, ref_prompt, pooling = True, weight = weight, alpha = alpha, skip = skip, batch_size = batch_size, normalize = False, normalize_pool = True, **kwargs)
        else:
            hidden_state2 = bigG(prompt, ref_prompt, pooling = False, weight = weight, alpha = alpha, skip = skip, batch_size = batch_size, normalize = False, **kwargs)
            # The pooled embedding always comes from the final layer (skip -1).
            pool_hidden_state = bigG(prompt, ref_prompt, pooling = True, weight = weight, alpha = alpha, skip = -1, batch_size = batch_size, normalize = False, normalize_pool = True, **kwargs)[1]
        hidden_state = torch.cat([hidden_state, hidden_state2], dim = -1)
        pool_hidden_state = bigG.projection_text_hidden_state(pool_hidden_state)
        return hidden_state.type(pool_hidden_state.dtype), pool_hidden_state
    return inference


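# For SDXL, the concatenation above should yield the shapes the pipeline expects
# (standard encoder widths, stated as an assumption rather than verified here):
# CLIP ViT-L/14 hidden states (768) + bigG hidden states (1280) -> 2048-dim
# prompt embeddings, with the projected bigG pooled embedding (1280) used as
# `pooled_prompt_embeds`.

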
def stable_diffusion_v3(large, bigG, t5):
    """
    Stable Diffusion 3.x:
    openai/clip-vit-large-patch14 (CLIPTextModelWithProjection), skip -2, unnormalized, pooling + projection
    laion/CLIP-ViT-bigG-14-laion2B-39B-b160k (CLIPTextModelWithProjection), skip -2, unnormalized, pooling + projection
    t5-v1_1-xxl (T5EncoderModel)
    """
    def inference(prompt, ref_prompt = None, weight = None, alpha = 0.3, skip = -2, batch_size = 64, **kwargs):
        if skip == -1:
            hidden_state, pool_hidden_state = large(prompt, ref_prompt, pooling = True, weight = weight, alpha = alpha, skip = skip, batch_size = batch_size, normalize = False, normalize_pool = True, **kwargs)
            hidden_state2, pool_hidden_state2 = bigG(prompt, ref_prompt, pooling = True, weight = weight, alpha = alpha, skip = skip, batch_size = batch_size, normalize = False, normalize_pool = True, **kwargs)
        else:
            hidden_state = large(prompt, ref_prompt, pooling = False, weight = weight, alpha = alpha, skip = skip, batch_size = batch_size, normalize = False, **kwargs)
            hidden_state2 = bigG(prompt, ref_prompt, pooling = False, weight = weight, alpha = alpha, skip = skip, batch_size = batch_size, normalize = False, **kwargs)
            # Pooled embeddings always come from the final layer (skip -1).
            pool_hidden_state = large(prompt, ref_prompt, pooling = True, weight = weight, alpha = alpha, skip = -1, batch_size = batch_size, normalize = False, normalize_pool = True, **kwargs)[1]
            pool_hidden_state2 = bigG(prompt, ref_prompt, pooling = True, weight = weight, alpha = alpha, skip = -1, batch_size = batch_size, normalize = False, normalize_pool = True, **kwargs)[1]
        hidden_state3 = t5(prompt, ref_prompt, pooling = False, weight = weight, alpha = alpha, batch_size = batch_size, normalize = False, **kwargs)
        hidden_state = torch.cat([hidden_state, hidden_state2], dim = -1)
        pool_hidden_state = large.projection_text_hidden_state(pool_hidden_state)
        pool_hidden_state2 = bigG.projection_text_hidden_state(pool_hidden_state2)
        # Pad the concatenated CLIP features to the T5 width, then stack along the sequence axis.
        hidden_state = torch.nn.functional.pad(hidden_state, (0, hidden_state3.shape[-1] - hidden_state.shape[-1]))
        hidden_state = torch.cat([hidden_state, hidden_state3], dim = -2)
        pool_hidden_state = torch.cat([pool_hidden_state, pool_hidden_state2], dim = -1)
        return hidden_state.type(pool_hidden_state.dtype), pool_hidden_state
    return inference


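# For SD 3.x, the layout above mirrors the stock pipeline (widths are the usual
# ones, stated as an assumption): CLIP-L (768) + bigG (1280) features are
# zero-padded to the T5-XXL width (4096) and stacked along the sequence axis,
# while the projected pooled embeddings concatenate to 768 + 1280 = 2048.

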
def flux(large, t5):
    """
    FLUX:
    openai/clip-vit-large-patch14 (CLIPTextModel), pooling only
    t5-v1_1-xxl (T5EncoderModel)
    """
    def inference(prompt, ref_prompt = None, weight = None, alpha = 0.3, skip = None, batch_size = 64, **kwargs):
        # `skip` is accepted for signature uniformity but unused: FLUX conditions
        # on the full T5 hidden states plus the final-layer pooled CLIP embedding.
        hidden_state = t5(prompt, ref_prompt, pooling = False, weight = weight, alpha = alpha, batch_size = batch_size, normalize = False, **kwargs)
        pool_hidden_state = large(prompt, ref_prompt, pooling = True, weight = weight, alpha = alpha, skip = -1, batch_size = batch_size, normalize = False, normalize_pool = True, **kwargs)[1]
        return hidden_state.type(pool_hidden_state.dtype), pool_hidden_state
    return inference


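# For FLUX, the T5-XXL hidden states (width 4096, assuming the stock encoder)
# become `prompt_embeds`, and the final-layer CLIP ViT-L/14 pooled embedding
# (768) becomes `pooled_prompt_embeds`; no cross-encoder concatenation is needed.

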
def peca(pipeline, save_path = "./weight", n_layer = 10):
    # Prefer legacy .pth checkpoints when present; otherwise load safetensors.
    if os.path.exists(os.path.join(save_path, "L.pth")) or os.path.exists(os.path.join(save_path, "H.pth")):
        load_func = torch.load
        postfix = "pth"
    else:
        load_func = load_file
        postfix = "safetensors"

    name = pipeline.config._name_or_path.split("/")[-1].lower()
    if "flux" in name:
        model = pipeline.text_encoder
        processor = pipeline.tokenizer
        model2 = pipeline.text_encoder_2
        processor2 = pipeline.tokenizer_2

        large = backbone(model, processor, n_layer = n_layer, pos = False, cls_pos = False, dropout = 0.0).to(pipeline.device).eval()
        large.adapter.load_state_dict(load_func(os.path.join(save_path, "L.{0}".format(postfix))))
        t5 = backbone(model2, processor2, n_layer = n_layer, encode_ratio = 4, pos = False, cls_pos = False, dropout = 0.0).to(pipeline.device).eval()
        t5.adapter.load_state_dict(load_func(os.path.join(save_path, "T5.{0}".format(postfix))))
        empty, pool = large.encode_prompt("", pooling = True, normalize = False, normalize_pool = False)
        large.adapter.set_base_query(torch.cat([empty, pool.unsqueeze(1)], dim = 1))
        empty, pool = t5.encode_prompt("", pooling = True, normalize = False, normalize_pool = False)
        t5.adapter.set_base_query(empty)

        feature_encoder = large
        encoder = flux(large, t5)
        size = 1024
        num_inference_steps = 28
        skip = -2
    elif "stable-diffusion-3.5" in name:
        model = pipeline.text_encoder
        processor = pipeline.tokenizer
        model2 = pipeline.text_encoder_2
        processor2 = pipeline.tokenizer_2
        model3 = pipeline.text_encoder_3
        processor3 = pipeline.tokenizer_3

        large = backbone(model, processor, n_layer = n_layer, pos = False, cls_pos = False, dropout = 0.0).to(pipeline.device).eval()
        large.adapter.load_state_dict(load_func(os.path.join(save_path, "L.{0}".format(postfix))))
        bigG = backbone(model2, processor2, n_layer = n_layer, pos = False, cls_pos = False, dropout = 0.0).to(pipeline.device).eval()
        bigG.adapter.load_state_dict(load_func(os.path.join(save_path, "bigG.{0}".format(postfix))))
        t5 = backbone(model3, processor3, n_layer = n_layer, encode_ratio = 4, pos = False, cls_pos = False, dropout = 0.0).to(pipeline.device).eval()
        t5.adapter.load_state_dict(load_func(os.path.join(save_path, "T5.{0}".format(postfix))))
        empty, pool = large.encode_prompt("", pooling = True, normalize = False, normalize_pool = False)
        large.adapter.set_base_query(torch.cat([empty, pool.unsqueeze(1)], dim = 1))
        empty, pool = bigG.encode_prompt("", pooling = True, normalize = False, normalize_pool = False)
        bigG.adapter.set_base_query(torch.cat([empty, pool.unsqueeze(1)], dim = 1))
        empty, pool = t5.encode_prompt("", pooling = True, normalize = False, normalize_pool = False)
        t5.adapter.set_base_query(empty)

        feature_encoder = large
        encoder = stable_diffusion_v3(large, bigG, t5)
        size = 1024
        num_inference_steps = 28
        skip = -2
    elif "xl-base" in name:
        model = pipeline.text_encoder
        processor = pipeline.tokenizer
        model2 = pipeline.text_encoder_2
        processor2 = pipeline.tokenizer_2

        large = backbone(model, processor, n_layer = n_layer, pos = False, cls_pos = False, dropout = 0.0).to(pipeline.device).eval()
        large.adapter.load_state_dict(load_func(os.path.join(save_path, "L.{0}".format(postfix))))
        bigG = backbone(model2, processor2, n_layer = n_layer, pos = False, cls_pos = False, dropout = 0.0).to(pipeline.device).eval()
        bigG.adapter.load_state_dict(load_func(os.path.join(save_path, "bigG.{0}".format(postfix))))
        empty, pool = large.encode_prompt("", pooling = True, normalize = False, normalize_pool = False)
        large.adapter.set_base_query(torch.cat([empty, pool.unsqueeze(1)], dim = 1))
        empty, pool = bigG.encode_prompt("", pooling = True, normalize = False, normalize_pool = False)
        bigG.adapter.set_base_query(torch.cat([empty, pool.unsqueeze(1)], dim = 1))

        feature_encoder = large
        encoder = stable_diffusion_xl(large, bigG)
        size = 1024
        num_inference_steps = 50
        skip = -2
    elif "stable-diffusion-2" in name:
        model = pipeline.text_encoder
        processor = pipeline.tokenizer

        huge = backbone(model, processor, n_layer = n_layer, pos = False, cls_pos = False, dropout = 0.0).to(pipeline.device).eval()
        huge.adapter.load_state_dict(load_func(os.path.join(save_path, "H.{0}".format(postfix))))
        empty, pool = huge.encode_prompt("", pooling = True, normalize = False, normalize_pool = False)
        huge.adapter.set_base_query(torch.cat([empty, pool.unsqueeze(1)], dim = 1))

        feature_encoder = huge
        encoder = stable_diffusion_v2(huge)
        size = 768
        num_inference_steps = 50
        skip = -1
    else:
        model = pipeline.text_encoder
        processor = pipeline.tokenizer

        large = backbone(model, processor, n_layer = n_layer, pos = False, cls_pos = False, dropout = 0.0).to(pipeline.device).eval()
        large.adapter.load_state_dict(load_func(os.path.join(save_path, "L.{0}".format(postfix))))
        empty, pool = large.encode_prompt("", pooling = True, normalize = False, normalize_pool = False)
        large.adapter.set_base_query(torch.cat([empty, pool.unsqueeze(1)], dim = 1))

        feature_encoder = large
        encoder = stable_diffusion(large)
        size = 512
        num_inference_steps = 50
        skip = -1
    return encoder, feature_encoder.get_text_feature, size, num_inference_steps, skip


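# A hedged sketch of calling `peca` directly (the adapter weights must already
# exist under `save_path`; the model id and prompts are illustrative):
#
#   pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
#   encoder, get_feature, size, steps, skip = peca(pipe, save_path = "./weight")
#   cond, pool_cond = encoder("a cat", ref_prompt = ["a watercolor"], alpha = 0.3)

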
class DrUM(DiffusionPipeline):
    def __init__(self, pipeline, repo_id = "Burf/DrUM", weight = None, torch_dtype = torch.bfloat16, device = "cuda"):
        """
        DrUM adapter for various T2I diffusion models.
        """
        self.pipeline = pipeline if not isinstance(pipeline, str) else self.load_pipeline(pipeline, torch_dtype = torch_dtype, device = device)
        self.repo_id = repo_id

        self.adapter, self.feature_encoder, self.size, self.num_inference_steps, self.skip = self.load_peca(self.pipeline, repo_id, weight)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, repo_id = "Burf/DrUM", torch_dtype = torch.bfloat16, device = "cuda", weight = None):
        """
        Load the DrUM adapter together with the appropriate pipeline.
        """
        pipeline = cls.load_pipeline(pretrained_model_name_or_path, torch_dtype, device)
        return cls(pipeline = pipeline, repo_id = repo_id, weight = weight, torch_dtype = torch_dtype, device = device)

    @staticmethod
    def load_pipeline(model_id, torch_dtype = torch.bfloat16, device = "cuda"):
        name = model_id.split("/")[-1].lower()
        if "flux" in name:
            pipeline = FluxPipeline.from_pretrained(model_id, torch_dtype = torch_dtype)
        elif "stable-diffusion-3.5" in name:
            pipeline = StableDiffusion3Pipeline.from_pretrained(model_id, torch_dtype = torch_dtype)
        else:
            pipeline = DiffusionPipeline.from_pretrained(model_id, torch_dtype = torch_dtype)

        pipeline = pipeline.to(device if torch.cuda.is_available() else "cpu")

        return pipeline

    def load_weight(self, pipeline, repo_id = "Burf/DrUM", weight = None):
        name = pipeline.config._name_or_path.split("/")[-1].lower()

        if "flux" in name:
            weights = ["L.safetensors", "T5.safetensors"]
        elif "stable-diffusion-3.5" in name:
            weights = ["L.safetensors", "bigG.safetensors", "T5.safetensors"]
        elif "xl-base" in name:
            weights = ["L.safetensors", "bigG.safetensors"]
        elif "stable-diffusion-2" in name:
            weights = ["H.safetensors"]
        else:
            weights = ["L.safetensors"]

        for weight_file in weights:
            if isinstance(weight, str) and os.path.exists(os.path.join(weight, weight_file)):
                # A local directory that already holds the file wins.
                weight_path = weight
                break
            else:
                safetensor_path = hf_hub_download(repo_id = repo_id, filename = "weight/" + weight_file)
                weight_path = os.path.dirname(safetensor_path)
        return weight_path

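    # Resolution sketch (hypothetical path): DrUM(pipe, weight = "./my_weights")
    # uses "./my_weights" as soon as it contains one of the expected files;
    # otherwise every file is downloaded from the Hub repo (default "Burf/DrUM",
    # under its "weight/" folder) and the cache directory of the downloads is
    # returned instead.
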
    def load_peca(self, pipeline, repo_id = "Burf/DrUM", weight = None):
        adapter, feature_encoder, size, num_inference_steps, skip = peca(pipeline, save_path = self.load_weight(pipeline, repo_id, weight))
        return adapter, feature_encoder, size, num_inference_steps, skip

    def __call__(self, prompt, ref = None, weight = None, alpha = 0.3, skip = None, sampling = False, seed = 42,
                 size = None, num_inference_steps = None, num_images_per_prompt = 1):
        """
        Generate images using the DrUM adapter.

        Args:
            prompt: Text prompt for generation
            ref: Reference prompts (list of strings)
            weight: Weights for the reference prompts (list of floats)
            alpha: Personalization strength (0-1)
            skip: Which hidden layer of the text encoder to condition on (e.g. -1 for the last layer, -2 for the penultimate one)
            sampling: Whether to use coreset sampling for reference selection (default: False)
            seed: Random seed
            size: Image size
            num_inference_steps: Inference steps
            num_images_per_prompt: Number of images to generate

        Returns:
            Personalized images (list of PIL Images)
        """
        size = self.size if size is None else size
        num_inference_steps = self.num_inference_steps if num_inference_steps is None else num_inference_steps
        skip = self.skip if skip is None else skip

        if sampling and isinstance(ref, (tuple, list)) and 1 < len(ref):
            import numpy as np

            with torch.no_grad():
                feature = self.feature_encoder(ref).cpu().float().numpy()

            indices = coreset_sampling(feature, weight = weight, seed = seed)
            # Subsample the weights before `ref` is replaced, so the length
            # check compares against the original reference list.
            if isinstance(weight, (tuple, list)) and len(weight) == len(ref):
                weight = np.array(weight)[indices].tolist()
            ref = np.array(ref)[indices].tolist()

        generator = torch.Generator(self.pipeline.device).manual_seed(seed)
        with torch.no_grad():
            cond, pool_cond = self.adapter(prompt, ref, weight = weight, alpha = alpha, skip = skip)

        pipe_kwargs = {
            "num_images_per_prompt": num_images_per_prompt,
            "num_inference_steps": num_inference_steps,
            "generator": generator,
            "height": size,
            "width": size
        }

        pipe_kwargs["prompt_embeds"] = cond.type(self.pipeline.dtype)
        if pool_cond is not None:
            pipe_kwargs["pooled_prompt_embeds"] = pool_cond.type(self.pipeline.dtype)

        name = self.pipeline.config._name_or_path.split("/")[-1].lower()
        if "flux" in name or "stable-diffusion-3" in name:
            pipe_kwargs["max_sequence_length"] = 256

        images = self.pipeline(**pipe_kwargs).images
        return images