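"""Model manager for image captioning, style analysis and image generation.

Wraps a BLIP captioning model, a CLIP model used for zero-shot style scoring,
and Stable Diffusion pipelines (plain and ControlNet-conditioned) behind a
single ``ModelManager`` class.
"""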
import gc
import logging
import os
import random
import time
from typing import Dict, Optional

import numpy as np
import torch
from diffusers import StableDiffusionPipeline, ControlNetModel, StableDiffusionControlNetPipeline, EulerAncestralDiscreteScheduler
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration, CLIPProcessor, CLIPModel
|
|
# Module-level logger used by everything in this file.
logger = logging.getLogger(__name__)

# Install a default root handler at INFO level. ``basicConfig`` does nothing
# when the root logger already has handlers, so an application that configured
# logging earlier keeps its own setup.
logging.basicConfig(level=logging.INFO)
|
|
class ModelManager:
    """Load and serve the captioning, style-scoring and image-generation models.

    On construction the manager picks a device (CUDA when available, CPU
    otherwise) and tries to load four components:

    * a BLIP captioning model and its processor,
    * a CLIP model and its processor (used for zero-shot style scoring),
    * a Stable Diffusion text-to-image pipeline,
    * a ControlNet model plus a ControlNet-conditioned Stable Diffusion pipeline.

    A loading failure is logged and swallowed by ``__init__``; components that
    did not load stay ``None`` and the corresponding public method raises
    ``RuntimeError`` when called. Use ``get_model_status`` to inspect what is
    available.
    """

    # Default download/cache location for all Hugging Face checkpoints.
    DEFAULT_CACHE_DIR = "/tmp/models"

    # (display name, CLIP text prompt) pairs used by ``analyze_style``.
    # Kept as pairs so names and prompts cannot drift out of alignment.
    _STYLE_PROMPTS = (
        ("商务正装", "business formal suit professional attire"),
        ("休闲风", "casual comfortable everyday wear"),
        ("运动风", "athletic sportswear activewear"),
        ("时尚潮流", "fashion trendy modern stylish"),
        ("复古风", "vintage retro classic style"),
        ("街头风", "streetwear urban contemporary"),
        ("优雅风", "elegant sophisticated refined"),
    )

    def __init__(self, cache_dir: str = DEFAULT_CACHE_DIR):
        """Select the device, set defaults and attempt to load every model.

        Args:
            cache_dir: Directory passed as ``cache_dir`` to every
                ``from_pretrained`` call.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info("使用设备: %s", self.device)

        self.cache_dir = cache_dir

        self.model_config = {
            "caption_model": "Salesforce/blip-image-captioning-large",
            "clip_model": "openai/clip-vit-large-patch14",
            "sd_model": "runwayml/stable-diffusion-v1-5",
            "controlnet_model": "lllyasviel/control_v11p_sd15_openpose",
        }

        self.caption_processor = None
        self.caption_model = None
        self.clip_processor = None
        self.clip_model = None
        self.sd_pipeline = None
        self.controlnet = None
        self.controlnet_pipeline = None

        # Half precision only on GPU; CPU inference stays in float32.
        self.torch_dtype = torch.float16 if self.device == "cuda" else torch.float32
        self.enable_attention_slicing = True
        # NOTE: this flag is stored but not consulted by any loader in this class.
        self.enable_cpu_offload = False

        try:
            self.load_all_models()
        except Exception as e:
            # Deliberately non-fatal: the instance stays usable for whatever
            # did load. The traceback is attached so the cause is not lost.
            logger.warning("加载模型时出错: %s", e, exc_info=True)

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _require(component, name: str):
        """Return ``component`` or raise ``RuntimeError`` if it was never loaded."""
        if component is None:
            raise RuntimeError(
                f"{name} is not loaded; check the startup log for the loading error"
            )
        return component

    @staticmethod
    def _round_down_to_multiple_of_8(value) -> int:
        """Round ``value`` down to a multiple of 8, never returning less than 8."""
        return max(8, (int(value) // 8) * 8)

    @staticmethod
    def _prepare_image(image: "Image.Image", max_side: int) -> "Image.Image":
        """Return an RGB image whose sides do not exceed ``max_side``.

        ``Image.thumbnail`` resizes in place, so the resize is applied to a
        copy; the caller's image object is never modified.
        """
        prepared = image if image.mode == "RGB" else image.convert("RGB")
        if prepared.width > max_side or prepared.height > max_side:
            if prepared is image:
                prepared = image.copy()
            prepared.thumbnail((max_side, max_side), Image.Resampling.LANCZOS)
        return prepared

    def _prepare_inputs(self, batch) -> dict:
        """Move processor outputs to the model device and dtype.

        Every tensor is moved to ``self.device``. Floating-point tensors are
        additionally cast to ``self.torch_dtype`` so they match half-precision
        weights; integer tensors keep their dtype. Non-tensor values are
        passed through unchanged.
        """
        prepared = {}
        for key, value in batch.items():
            if torch.is_tensor(value):
                value = value.to(self.device)
                if value.is_floating_point():
                    value = value.to(self.torch_dtype)
            prepared[key] = value
        return prepared

    def _configure_pipeline(self, pipeline, name: str) -> None:
        """Apply the scheduler and memory options shared by both pipelines."""
        pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
            pipeline.scheduler.config
        )
        if self.enable_attention_slicing:
            pipeline.enable_attention_slicing()
        try:
            pipeline.enable_xformers_memory_efficient_attention()
        except Exception as exc:
            # Best effort: the pipeline is left as configured above when
            # this call fails.
            logger.debug("xformers attention not enabled for %s: %s", name, exc)
        pipeline.enable_vae_slicing()

    # ------------------------------------------------------------------
    # Loading
    # ------------------------------------------------------------------
    def optimize_memory_usage(self) -> None:
        """Enable cuDNN autotuning and TF32 kernels when CUDA is available."""
        if torch.cuda.is_available():
            torch.backends.cudnn.benchmark = True
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True

    def load_all_models(self) -> None:
        """Load every component in order; the first failure propagates."""
        self.optimize_memory_usage()
        self.load_caption_model()
        self.load_clip_model()
        self.load_sd_pipeline()
        self.load_controlnet_pipeline()
        logger.info("所有模型加载完成")

    def load_caption_model(self) -> None:
        """Load the BLIP processor and captioning model onto ``self.device``."""
        name = self.model_config["caption_model"]
        self.caption_processor = BlipProcessor.from_pretrained(
            name, cache_dir=self.cache_dir
        )
        self.caption_model = BlipForConditionalGeneration.from_pretrained(
            name,
            cache_dir=self.cache_dir,
            torch_dtype=self.torch_dtype,
            low_cpu_mem_usage=True,
        ).to(self.device)
        # No ``enable_attention_slicing()`` here: that is a diffusers pipeline
        # method. Calling it on the BLIP model raised AttributeError and
        # aborted the loading of every model that follows.
        self.caption_model.eval()

    def load_clip_model(self) -> None:
        """Load the CLIP processor and model onto ``self.device``."""
        name = self.model_config["clip_model"]
        self.clip_processor = CLIPProcessor.from_pretrained(
            name, cache_dir=self.cache_dir
        )
        self.clip_model = CLIPModel.from_pretrained(
            name, cache_dir=self.cache_dir, torch_dtype=self.torch_dtype
        ).to(self.device)
        self.clip_model.eval()

    def load_sd_pipeline(self) -> None:
        """Load the Stable Diffusion text-to-image pipeline (no safety checker)."""
        self.sd_pipeline = StableDiffusionPipeline.from_pretrained(
            self.model_config["sd_model"],
            torch_dtype=self.torch_dtype,
            cache_dir=self.cache_dir,
            safety_checker=None,
            requires_safety_checker=False,
            use_safetensors=True,
            low_cpu_mem_usage=True,
        ).to(self.device)
        self._configure_pipeline(self.sd_pipeline, "sd_pipeline")
        self.sd_pipeline.safety_checker = None

    def load_controlnet_pipeline(self) -> None:
        """Load the ControlNet model and the ControlNet-conditioned pipeline."""
        self.controlnet = ControlNetModel.from_pretrained(
            self.model_config["controlnet_model"],
            cache_dir=self.cache_dir,
            torch_dtype=self.torch_dtype,
            low_cpu_mem_usage=True,
        ).to(self.device)
        self.controlnet_pipeline = StableDiffusionControlNetPipeline.from_pretrained(
            self.model_config["sd_model"],
            controlnet=self.controlnet,
            cache_dir=self.cache_dir,
            torch_dtype=self.torch_dtype,
            safety_checker=None,
            requires_safety_checker=False,
            low_cpu_mem_usage=True,
        ).to(self.device)
        self._configure_pipeline(self.controlnet_pipeline, "controlnet_pipeline")

    # ------------------------------------------------------------------
    # Inference
    # ------------------------------------------------------------------
    @torch.no_grad()
    def generate_caption(self, image: "Image.Image") -> str:
        """Return a BLIP caption for ``image``.

        The image is converted to RGB and downsized to at most 512 px per
        side on a copy. Decoding uses 4 beams with sampling enabled, so the
        caption is not deterministic.

        Raises:
            RuntimeError: If the caption model or processor is not loaded.
        """
        processor = self._require(self.caption_processor, "caption processor")
        model = self._require(self.caption_model, "caption model")

        image = self._prepare_image(image, 512)
        inputs = self._prepare_inputs(processor(images=image, return_tensors="pt"))
        outputs = model.generate(
            **inputs, max_length=50, num_beams=4, temperature=0.7, do_sample=True
        )
        caption = processor.decode(outputs[0], skip_special_tokens=True)

        del inputs, outputs
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return caption

    @torch.no_grad()
    def analyze_style(self, image: "Image.Image") -> Dict[str, float]:
        """Score ``image`` against the fixed style prompts with CLIP.

        Returns:
            Mapping from style display name to its softmax probability over
            the seven prompts in ``_STYLE_PROMPTS``.

        Raises:
            RuntimeError: If the CLIP model or processor is not loaded.
        """
        processor = self._require(self.clip_processor, "CLIP processor")
        model = self._require(self.clip_model, "CLIP model")

        style_names = [name for name, _ in self._STYLE_PROMPTS]
        style_labels = [prompt for _, prompt in self._STYLE_PROMPTS]

        image = self._prepare_image(image, 224)
        inputs = self._prepare_inputs(
            processor(
                text=style_labels,
                images=image,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=77,
            )
        )
        outputs = model(**inputs)
        # Softmax in float32 so half-precision logits do not lose accuracy.
        probs = outputs.logits_per_image.float().softmax(dim=1)[0].tolist()
        return dict(zip(style_names, probs))

    @torch.no_grad()
    def generate_image(
        self,
        prompt: str,
        negative_prompt: Optional[str] = None,
        num_inference_steps: int = 25,
        guidance_scale: float = 7.5,
        width: int = 512,
        height: int = 512,
        seed: Optional[int] = None,
    ) -> "Image.Image":
        """Generate one image from ``prompt`` with the Stable Diffusion pipeline.

        Args:
            prompt: Text prompt.
            negative_prompt: Negative prompt; a generic quality prompt is used
                when ``None``.
            num_inference_steps: Number of denoising steps.
            guidance_scale: Classifier-free guidance scale.
            width: Output width, rounded down to a multiple of 8 (minimum 8).
            height: Output height, rounded down to a multiple of 8 (minimum 8).
            seed: Seed for a reproducible generator; ``None`` leaves the
                generator unset.

        Raises:
            RuntimeError: If the Stable Diffusion pipeline is not loaded.
        """
        pipeline = self._require(self.sd_pipeline, "Stable Diffusion pipeline")

        if negative_prompt is None:
            negative_prompt = "blurry, low quality, distorted, text, watermark, ugly, deformed"
        width = self._round_down_to_multiple_of_8(width)
        height = self._round_down_to_multiple_of_8(height)

        gen = None
        if seed is not None:
            gen = torch.Generator(device=self.device).manual_seed(int(seed))

        result = pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            height=height,
            width=width,
            generator=gen,
        )
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return result.images[0]

    @torch.no_grad()
    def generate_controlnet_image(
        self,
        image: "Image.Image",
        prompt: str,
        reference_image: Optional["Image.Image"] = None,
        negative_prompt: Optional[str] = None,
        num_inference_steps: int = 30,
        guidance_scale: float = 8.0,
        angle: float = 0,
        width: int = 512,
        height: int = 768,
    ) -> "Image.Image":
        """Generate one image conditioned on ``image`` through ControlNet.

        Args:
            image: Conditioning image; converted to RGB and resized to
                ``(width, height)``.
            prompt: Text prompt; a "view from N degrees" suffix is appended.
            reference_image: When given, only adds a hint to the prompt; the
                pixels themselves are not used.
            negative_prompt: Negative prompt; a default is used when ``None``.
            num_inference_steps: Number of denoising steps.
            guidance_scale: Classifier-free guidance scale.
            angle: Viewing angle in degrees, used in the prompt and the seed.
            width: Control-image width, rounded down to a multiple of 8.
            height: Control-image height, rounded down to a multiple of 8.

        Raises:
            RuntimeError: If the ControlNet pipeline is not loaded.
        """
        pipeline = self._require(self.controlnet_pipeline, "ControlNet pipeline")

        if image.mode != "RGB":
            image = image.convert("RGB")
        # Honour the requested size; it used to be hard-coded to (512, 768),
        # which is still what the defaults produce.
        width = self._round_down_to_multiple_of_8(width)
        height = self._round_down_to_multiple_of_8(height)
        # NOTE(review): the configured ControlNet checkpoint is an OpenPose
        # one, yet the resized input image itself is passed as the control
        # image - confirm callers pass a pose map or that this is intended.
        control_image = image.resize((width, height), Image.Resampling.LANCZOS)

        if negative_prompt is None:
            negative_prompt = "blurry, distorted, low quality, unrealistic, extra limbs, deformed, bad anatomy, multiple people"

        prompt_with_angle = f"{prompt}, view from {angle} degrees"
        if reference_image is not None:
            prompt_with_angle = f"{prompt_with_angle}, based on provided reference design"

        # Time-based seed offset by the angle, so each call and each angle
        # gets a different generator state.
        gen = torch.Generator(device=self.device).manual_seed(int(time.time()) + int(angle))

        result = pipeline(
            prompt=prompt_with_angle,
            image=control_image,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            controlnet_conditioning_scale=1.0,
            generator=gen,
        )
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return result.images[0]

    # ------------------------------------------------------------------
    # Utilities
    # ------------------------------------------------------------------
    def create_placeholder_image(self, width: int, height: int) -> "Image.Image":
        """Return a solid-colour RGB image in one of four randomly chosen tints."""
        color = random.choice([(220, 220, 220), (200, 220, 240), (240, 220, 200), (220, 240, 200)])
        return Image.new('RGB', (width, height), color=color)

    def cleanup(self) -> None:
        """Run the garbage collector and release cached CUDA memory."""
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            try:
                torch.cuda.ipc_collect()
            except Exception as exc:
                # Best effort: failing to collect IPC handles is not fatal.
                logger.debug("torch.cuda.ipc_collect() failed: %s", exc)

    def get_model_status(self) -> Dict[str, object]:
        """Report which components are loaded, the device, and GPU memory use.

        Returns:
            Dict with a boolean per component, the ``device`` string and, when
            CUDA is available, a ``gpu_memory`` dict with formatted
            ``allocated`` and ``cached`` sizes in GB.
        """
        status = {
            "caption_model": self.caption_model is not None,
            "clip_model": self.clip_model is not None,
            "sd_pipeline": self.sd_pipeline is not None,
            "controlnet_pipeline": self.controlnet_pipeline is not None,
            "device": self.device,
        }
        if torch.cuda.is_available():
            status["gpu_memory"] = {
                "allocated": f"{torch.cuda.memory_allocated() / 1024**3:.2f}GB",
                "cached": f"{torch.cuda.memory_reserved() / 1024**3:.2f}GB",
            }
        return status
|
|