| |
| import torch |
| from PIL import Image |
| from transformers import ( |
| BlipProcessor, |
| BlipForConditionalGeneration, |
| CLIPProcessor, |
| CLIPModel |
| ) |
| from diffusers import ( |
| StableDiffusionPipeline, |
| StableDiffusionControlNetPipeline, |
| ControlNetModel, |
| EulerAncestralDiscreteScheduler |
| ) |
| import numpy as np |
| import gc |
| import os |
| import logging |
| import time |
| from typing import Optional, Dict, List, Tuple |
|
|
| |
# Configure the root logger at INFO level when this module is imported.
| logging.basicConfig(level=logging.INFO) |
# Module-level logger used by ModelManager for status and error messages.
| logger = logging.getLogger(__name__) |
|
|
class ModelManager:
    """Lazily load, run and unload the captioning, CLIP and diffusion models.

    Each model is loaded with ``from_pretrained`` on first use and cached
    under ``cache_dir``. Weights are float16 on CUDA and float32 on CPU.
    Loading and inference errors are logged and answered with a fallback
    value rather than raised to the caller.
    """

    # Stable Diffusion checkpoint used when GPU memory is limited or when the
    # configured checkpoint cannot be loaded.
    _SD_FALLBACK_MODEL = "runwayml/stable-diffusion-v1-5"

    def __init__(self, cache_dir: str = "/tmp/models"):
        """Detect the device, create the cache directory and reset all slots.

        Args:
            cache_dir: Directory passed as ``cache_dir`` to every
                ``from_pretrained`` call. Created if it does not exist.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"使用设备: {self.device}")

        # Model slots stay None until the matching load_* method succeeds.
        self.caption_model = None
        self.caption_processor = None
        self.clip_model = None
        self.clip_processor = None
        self.sd_pipeline = None
        self.controlnet_pipeline = None
        self.controlnet = None

        self.model_config = {
            "caption_model": "Salesforce/blip-image-captioning-base",
            "clip_model": "openai/clip-vit-base-patch32",
            "sd_model": "stabilityai/stable-diffusion-2-1-base",
            "controlnet_model": "lllyasviel/sd-controlnet-openpose",
            # sd-controlnet-openpose is a ControlNet for Stable Diffusion 1.5,
            # so it gets its own SD 1.5 base instead of sharing "sd_model"
            # (SD 2.1 by default), which it is not compatible with.
            "controlnet_base_model": self._SD_FALLBACK_MODEL,
        }

        self.cache_dir = cache_dir
        os.makedirs(self.cache_dir, exist_ok=True)
        logger.info(f"模型缓存目录: {self.cache_dir}")

        # Seconds spent loading each model / wall-clock time of its last use,
        # keyed by "caption", "clip", "sd" and "controlnet".
        self.load_times = {}
        self.last_used = {}

    def _torch_dtype(self):
        """Return the weight dtype for the current device."""
        return torch.float16 if self.device == "cuda" else torch.float32

    def _enable_memory_optimizations(self, pipeline, xformers_warning: str):
        """Apply the CUDA memory optimizations to a diffusers pipeline.

        xformers is optional, so failing to enable it only logs
        ``xformers_warning``.
        """
        # NOTE(review): the original nesting of enable_attention_slicing()
        # relative to the CUDA check could not be recovered from the
        # extracted source - confirm it should be CUDA-only.
        if self.device == "cuda":
            try:
                pipeline.enable_xformers_memory_efficient_attention()
            except Exception:
                logger.warning(xformers_warning)
            pipeline.enable_attention_slicing()

    def _record_loaded(self, key: str, start_time: float):
        """Store the load duration and mark the model as just used."""
        self.load_times[key] = time.time() - start_time
        self.last_used[key] = time.time()

    def load_caption_model(self):
        """Load the BLIP captioning model and processor if not loaded yet.

        A failure is logged and leaves ``caption_model`` as None. There is
        no retry: the only known checkpoint is the configured one, so
        retrying would repeat the same failure.
        """
        if self.caption_model is not None:
            return

        start_time = time.time()
        logger.info("正在加载图像描述模型...")
        model_name = self.model_config["caption_model"]

        try:
            processor = BlipProcessor.from_pretrained(
                model_name,
                cache_dir=self.cache_dir,
            )
            model = BlipForConditionalGeneration.from_pretrained(
                model_name,
                cache_dir=self.cache_dir,
                torch_dtype=self._torch_dtype(),
            ).to(self.device)
        except Exception as e:
            logger.error(f"加载描述模型失败: {e}")
            return

        # Assign only after both parts loaded so no half-initialised state
        # is left behind.
        self.caption_processor = processor
        self.caption_model = model
        logger.info("图像描述模型加载完成")
        self._record_loaded("caption", start_time)

    def load_clip_model(self):
        """Load the CLIP model and processor if not loaded yet.

        A failure is logged and leaves ``clip_model`` as None.
        """
        if self.clip_model is not None:
            return

        start_time = time.time()
        logger.info("正在加载CLIP模型...")
        model_name = self.model_config["clip_model"]

        try:
            processor = CLIPProcessor.from_pretrained(
                model_name,
                cache_dir=self.cache_dir,
            )
            model = CLIPModel.from_pretrained(
                model_name,
                cache_dir=self.cache_dir,
                torch_dtype=self._torch_dtype(),
            ).to(self.device)
        except Exception as e:
            logger.error(f"加载CLIP模型失败: {e}")
            return

        self.clip_processor = processor
        self.clip_model = model
        logger.info("CLIP模型加载完成")
        self._record_loaded("clip", start_time)

    def load_sd_pipeline(self):
        """Load the Stable Diffusion text-to-image pipeline if not loaded yet.

        On CUDA devices with less than 10 GiB of total memory the configured
        checkpoint is replaced by the fallback checkpoint. If the configured
        checkpoint fails to load, the fallback is tried once. When every
        candidate fails, the errors are logged and ``sd_pipeline`` stays None.
        """
        if self.sd_pipeline is not None:
            return

        start_time = time.time()
        logger.info("正在加载Stable Diffusion模型...")

        if (
            self.device == "cuda"
            and torch.cuda.get_device_properties(0).total_memory < 10 * 1024**3
        ):
            logger.info("检测到有限GPU内存,使用更小的SD模型")
            self.model_config["sd_model"] = self._SD_FALLBACK_MODEL

        # Each candidate is tried at most once, so a persistent failure
        # (e.g. no network) cannot recurse or loop forever.
        candidates = [self.model_config["sd_model"]]
        if self._SD_FALLBACK_MODEL not in candidates:
            candidates.append(self._SD_FALLBACK_MODEL)

        for model_name in candidates:
            try:
                pipeline = StableDiffusionPipeline.from_pretrained(
                    model_name,
                    cache_dir=self.cache_dir,
                    safety_checker=None,
                    torch_dtype=self._torch_dtype(),
                ).to(self.device)
                self._enable_memory_optimizations(
                    pipeline, "无法启用xformers,使用回退方案"
                )
            except Exception as e:
                logger.error(f"加载SD模型失败: {e}")
                continue

            # Remember which checkpoint actually loaded.
            self.model_config["sd_model"] = model_name
            self.sd_pipeline = pipeline
            logger.info("Stable Diffusion模型加载完成")
            self._record_loaded("sd", start_time)
            return

    def load_controlnet_pipeline(self):
        """Load the ControlNet model and its pipeline if not loaded yet.

        The pipeline is built on ``model_config["controlnet_base_model"]``,
        falling back to ``model_config["sd_model"]`` when that key is absent.
        A failure is logged and leaves both ``controlnet`` and
        ``controlnet_pipeline`` as None.
        """
        if self.controlnet_pipeline is not None:
            return

        start_time = time.time()
        logger.info("正在加载ControlNet模型...")
        base_model = self.model_config.get(
            "controlnet_base_model", self.model_config["sd_model"]
        )

        try:
            controlnet = ControlNetModel.from_pretrained(
                self.model_config["controlnet_model"],
                cache_dir=self.cache_dir,
                torch_dtype=self._torch_dtype(),
            )
            pipeline = StableDiffusionControlNetPipeline.from_pretrained(
                base_model,
                controlnet=controlnet,
                cache_dir=self.cache_dir,
                safety_checker=None,
                torch_dtype=self._torch_dtype(),
            ).to(self.device)
            pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
                pipeline.scheduler.config
            )
            self._enable_memory_optimizations(
                pipeline, "无法为ControlNet启用xformers"
            )
        except Exception as e:
            logger.error(f"加载ControlNet模型失败: {e}")
            return

        # Assign only on full success; a lone ControlNet without its
        # pipeline could otherwise never be unloaded.
        self.controlnet = controlnet
        self.controlnet_pipeline = pipeline
        logger.info("ControlNet模型加载完成")
        self._record_loaded("controlnet", start_time)

    def generate_caption(self, image: "Image.Image") -> str:
        """Generate a caption for ``image`` with the BLIP model.

        Returns:
            The decoded caption, or the default string "时尚服装设计" when
            loading or generation fails.
        """
        try:
            self.load_caption_model()
            if self.caption_model is None:
                raise RuntimeError("图像描述模型不可用")
            self.last_used["caption"] = time.time()

            inputs = self.caption_processor(
                images=image,
                return_tensors="pt",
            ).to(self.device, self._torch_dtype())

            output = self.caption_model.generate(**inputs, max_length=50)
            caption = self.caption_processor.decode(
                output[0], skip_special_tokens=True
            )

            logger.info(f"生成的标题: {caption}")
            return caption
        except Exception as e:
            logger.error(f"生成标题失败: {e}")
            return "时尚服装设计"

    def analyze_style(self, image: "Image.Image") -> Dict[str, float]:
        """Score ``image`` against a fixed list of style labels using CLIP.

        Returns:
            A dict mapping the three most probable labels to their softmax
            probability, most probable first. On failure a fixed default
            dict is returned.
        """
        try:
            self.load_clip_model()
            if self.clip_model is None:
                raise RuntimeError("CLIP模型不可用")
            self.last_used["clip"] = time.time()

            style_labels = [
                "商务正装", "休闲风", "运动风", "时尚潮流",
                "复古风", "街头风", "优雅风", "民族风",
            ]

            inputs = self.clip_processor(
                text=style_labels,
                images=image,
                return_tensors="pt",
                padding=True,
            ).to(self.device)
            # The processor emits float32 pixels; cast them to the model's
            # dtype so a float16 model on CUDA does not hit a dtype mismatch.
            # Token ids stay integer.
            inputs["pixel_values"] = inputs["pixel_values"].to(
                self._torch_dtype()
            )

            # Inference only: skip building the autograd graph.
            with torch.no_grad():
                outputs = self.clip_model(**inputs)

            # Softmax in float32 for numerical stability with float16 logits.
            probs = (
                outputs.logits_per_image.float()
                .softmax(dim=1)
                .detach()
                .cpu()
                .numpy()[0]
            )

            top3_idx = np.argsort(probs)[-3:][::-1]
            top_styles = {style_labels[i]: float(probs[i]) for i in top3_idx}

            logger.info(f"风格分析结果: {top_styles}")
            return top_styles
        except Exception as e:
            logger.error(f"风格分析失败: {e}")
            return {"休闲风": 0.8, "时尚潮流": 0.7}

    def generate_image(
        self,
        prompt: str,
        negative_prompt: str = "",
        num_inference_steps: int = 30,
        guidance_scale: float = 7.5,
        height: int = 512,
        width: int = 512
    ) -> "Image.Image":
        """Generate an image from ``prompt`` with the Stable Diffusion pipeline.

        Returns:
            The first generated image, or a 512x512 light-grey placeholder
            when loading or generation fails.
        """
        try:
            self.load_sd_pipeline()
            if self.sd_pipeline is None:
                raise RuntimeError("Stable Diffusion模型不可用")
            self.last_used["sd"] = time.time()

            with torch.autocast("cuda" if self.device == "cuda" else "cpu"):
                image = self.sd_pipeline(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    height=height,
                    width=width,
                ).images[0]

            logger.info(f"成功生成设计图像: {prompt[:50]}...")
            return image
        except Exception as e:
            logger.error(f"生成设计图像失败: {e}")
            return Image.new('RGB', (512, 512), color=(220, 220, 220))

    def generate_controlnet_image(
        self,
        image: "Image.Image",
        prompt: str,
        negative_prompt: str = "",
        num_inference_steps: int = 35,
        guidance_scale: float = 8.0
    ) -> "Image.Image":
        """Generate an image conditioned on ``image`` via the ControlNet pipeline.

        Returns:
            The first generated image. When loading or generation fails, the
            result of ``generate_image`` for the same prompt, negative prompt
            and step count is returned instead.
        """
        try:
            self.load_controlnet_pipeline()
            if self.controlnet_pipeline is None:
                raise RuntimeError("ControlNet模型不可用")
            self.last_used["controlnet"] = time.time()

            with torch.autocast("cuda" if self.device == "cuda" else "cpu"):
                result = self.controlnet_pipeline(
                    prompt=prompt,
                    image=image,
                    negative_prompt=negative_prompt,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    controlnet_conditioning_scale=0.8,
                ).images[0]

            logger.info("成功生成3D试穿图像")
            return result
        except Exception as e:
            logger.error(f"生成3D试穿图像失败: {e}")
            # Fall back to plain text-to-image generation.
            return self.generate_image(
                prompt,
                negative_prompt,
                num_inference_steps
            )

    def unload_model(self, model_type: str):
        """Drop the references to one model and release its memory.

        Args:
            model_type: One of "caption", "clip", "sd" or "controlnet". An
                unknown value is logged as a warning. A known model that is
                not loaded is left untouched.
        """
        logger.info(f"卸载模型: {model_type}")

        # model_type -> (attributes to clear, message logged after clearing).
        # The first attribute decides whether the model counts as loaded.
        slots = {
            "caption": (
                ("caption_model", "caption_processor"),
                "卸载图像描述模型",
            ),
            "clip": (("clip_model", "clip_processor"), "卸载CLIP模型"),
            "sd": (("sd_pipeline",), "卸载Stable Diffusion模型"),
            "controlnet": (
                ("controlnet_pipeline", "controlnet"),
                "卸载ControlNet模型",
            ),
        }

        entry = slots.get(model_type)
        if entry is None:
            logger.warning(f"未知的模型类型: {model_type}")
        else:
            attr_names, message = entry
            if getattr(self, attr_names[0]) is not None:
                for attr_name in attr_names:
                    setattr(self, attr_name, None)
                logger.info(message)

        self.cleanup_memory()

    def cleanup(self):
        """Drop the references to every model and release their memory."""
        logger.info("清理所有模型释放内存...")

        # Rebinding to None releases the references and keeps the attributes
        # defined for later load_* calls.
        self.caption_model = None
        self.caption_processor = None
        self.clip_model = None
        self.clip_processor = None
        self.sd_pipeline = None
        self.controlnet_pipeline = None
        self.controlnet = None

        self.cleanup_memory()
        logger.info("内存清理完成")

    def cleanup_memory(self):
        """Run the garbage collector, then release cached CUDA memory."""
        # Collect first: empty_cache() can only return blocks whose tensors
        # have already been freed, and objects in reference cycles are not
        # freed until the collector runs.
        gc.collect()

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def get_memory_usage(self) -> Dict[str, float]:
        """Return GPU memory figures in GiB, or an empty dict without CUDA.

        Keys are "gpu_total", "gpu_used" (memory allocated by torch) and
        "gpu_free" (total minus used).
        """
        mem_info = {}

        if torch.cuda.is_available():
            gib = 1024**3
            total = torch.cuda.get_device_properties(0).total_memory / gib
            used = torch.cuda.memory_allocated() / gib
            mem_info["gpu_total"] = total
            mem_info["gpu_used"] = used
            mem_info["gpu_free"] = total - used

        return mem_info

    def get_model_status(self) -> Dict[str, str]:
        """Return the load state, load time and last-use age of each model."""

        def state(model) -> str:
            # Compare against None explicitly: an object defining __len__
            # can be falsy even though it is loaded.
            return "已加载" if model is not None else "未加载"

        status = {
            "caption_model": state(self.caption_model),
            "clip_model": state(self.clip_model),
            "sd_model": state(self.sd_pipeline),
            "controlnet_model": state(self.controlnet_pipeline),
        }

        for model in ["caption", "clip", "sd", "controlnet"]:
            if model in self.load_times:
                status[f"{model}_load_time"] = f"{self.load_times[model]:.2f}秒"
            if model in self.last_used:
                mins_ago = (time.time() - self.last_used[model]) / 60
                status[f"{model}_last_used"] = f"{mins_ago:.1f}分钟前"

        return status

    def __del__(self):
        """Release resources on destruction without ever raising."""
        try:
            self.cleanup()
        except Exception:
            # A finalizer may run on a partially initialised instance or
            # during interpreter shutdown, when module globals are already
            # gone; nothing useful can be done with an error here.
            pass