# models/model_manager.py
import torch
from PIL import Image
from transformers import (
    BlipProcessor,
    BlipForConditionalGeneration,
    CLIPProcessor,
    CLIPModel,
)
from diffusers import (
    StableDiffusionPipeline,
    StableDiffusionControlNetPipeline,
    ControlNetModel,
    EulerAncestralDiscreteScheduler,
)
import numpy as np
import gc
import os
import logging
import time
from typing import Optional, Dict, List, Tuple

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ModelManager:
    """Lazily load, run and unload the captioning, CLIP and diffusion models.

    Every ``load_*`` method is a no-op when its model is already loaded and
    never raises: on failure it logs the error and leaves the model ``None``.
    The ``generate_*`` / ``analyze_style`` methods load on demand and return a
    fallback value (default caption, default styles, grey placeholder image)
    instead of raising.
    """

    # Checkpoints retried once when the configured one fails to load.
    _CAPTION_FALLBACK_MODEL = "Salesforce/blip-image-captioning-base"
    # NOTE(review): the runwayml/* repo may no longer be hosted on the Hub;
    # verify this ID (a community mirror may be needed).
    _SD_FALLBACK_MODEL = "runwayml/stable-diffusion-v1-5"

    def __init__(self):
        # Pick the device automatically.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"使用设备: {self.device}")

        # All models start unloaded and are loaded lazily on first use.
        self.caption_model = None
        self.caption_processor = None
        self.clip_model = None
        self.clip_processor = None
        self.sd_pipeline = None
        self.controlnet_pipeline = None
        self.controlnet = None

        # Model configuration - smaller variants chosen to fit a Space environment.
        self.model_config = {
            "caption_model": "Salesforce/blip-image-captioning-base",  # base variant saves memory
            "clip_model": "openai/clip-vit-base-patch32",  # base CLIP
            "sd_model": "stabilityai/stable-diffusion-2-1-base",  # SD 2.1 base
            "controlnet_model": "lllyasviel/sd-controlnet-openpose",  # pose-control model
            # The openpose ControlNet above targets SD 1.5. SD 2.x uses a
            # different text-embedding width, so the ControlNet pipeline needs
            # an SD 1.5 base regardless of "sd_model".
            "controlnet_base_model": self._SD_FALLBACK_MODEL,
        }

        # Cache directory inside the Space's temporary storage.
        self.cache_dir = "/tmp/models"
        os.makedirs(self.cache_dir, exist_ok=True)
        logger.info(f"模型缓存目录: {self.cache_dir}")

        # Load statistics: seconds spent loading, and last-use timestamps.
        self.load_times = {}
        self.last_used = {}

    @property
    def _torch_dtype(self):
        """Weight/input dtype: float16 on CUDA, float32 on CPU."""
        return torch.float16 if self.device == "cuda" else torch.float32

    @staticmethod
    def _enable_memory_optimizations(pipeline, xformers_warning: str) -> None:
        """Prefer xformers attention; fall back to attention slicing.

        Both features install attention processors on the UNet, so enabling
        slicing after xformers would replace the xformers processors. Slicing
        is therefore used only when xformers cannot be enabled.
        """
        try:
            pipeline.enable_xformers_memory_efficient_attention()
        except Exception:
            # xformers is optional (not installed / unsupported GPU).
            logger.warning(xformers_warning)
            pipeline.enable_attention_slicing()

    def load_caption_model(self) -> None:
        """Load the BLIP captioning processor and model if not loaded yet.

        On failure, retries once with ``_CAPTION_FALLBACK_MODEL`` when the
        configured checkpoint differs from it; otherwise logs and gives up,
        leaving ``caption_model`` as ``None``.
        """
        if self.caption_model is not None:
            return
        start_time = time.time()
        logger.info("正在加载图像描述模型...")
        try:
            self.caption_processor = BlipProcessor.from_pretrained(
                self.model_config["caption_model"],
                cache_dir=self.cache_dir,
            )
            # torch_dtype already yields fp16 weights on CUDA; no extra .half().
            self.caption_model = BlipForConditionalGeneration.from_pretrained(
                self.model_config["caption_model"],
                cache_dir=self.cache_dir,
                torch_dtype=self._torch_dtype,
            ).to(self.device)
            logger.info("图像描述模型加载完成")
            self.load_times["caption"] = time.time() - start_time
            self.last_used["caption"] = time.time()
        except Exception as e:
            logger.error(f"加载描述模型失败: {str(e)}")
            self.caption_model = None
            self.caption_processor = None
            # Retry only if there is a *different* checkpoint to fall back to;
            # retrying the same one would recurse forever on a persistent
            # failure (e.g. no network).
            if self.model_config["caption_model"] != self._CAPTION_FALLBACK_MODEL:
                self.model_config["caption_model"] = self._CAPTION_FALLBACK_MODEL
                self.load_caption_model()

    def load_clip_model(self) -> None:
        """Load the CLIP processor and model used for style analysis.

        On failure, logs the error and leaves ``clip_model`` as ``None``.
        """
        if self.clip_model is not None:
            return
        start_time = time.time()
        logger.info("正在加载CLIP模型...")
        try:
            self.clip_processor = CLIPProcessor.from_pretrained(
                self.model_config["clip_model"],
                cache_dir=self.cache_dir,
            )
            # torch_dtype already yields fp16 weights on CUDA; no extra .half().
            self.clip_model = CLIPModel.from_pretrained(
                self.model_config["clip_model"],
                cache_dir=self.cache_dir,
                torch_dtype=self._torch_dtype,
            ).to(self.device)
            logger.info("CLIP模型加载完成")
            self.load_times["clip"] = time.time() - start_time
            self.last_used["clip"] = time.time()
        except Exception as e:
            logger.error(f"加载CLIP模型失败: {str(e)}")
            self.clip_model = None
            self.clip_processor = None

    def load_sd_pipeline(self) -> None:
        """Load the Stable Diffusion text-to-image pipeline if not loaded yet.

        On CUDA devices with less than 10 GiB of memory, ``model_config["sd_model"]``
        is switched to ``_SD_FALLBACK_MODEL`` before loading. On failure, retries
        once with the fallback checkpoint if it was not already being used;
        otherwise logs and leaves ``sd_pipeline`` as ``None``.
        """
        if self.sd_pipeline is not None:
            return
        start_time = time.time()
        logger.info("正在加载Stable Diffusion模型...")

        # Choose the model variant according to available GPU memory.
        if (
            self.device == "cuda"
            and torch.cuda.get_device_properties(0).total_memory < 10 * 1024**3
        ):
            logger.info("检测到有限GPU内存,使用更小的SD模型")
            self.model_config["sd_model"] = self._SD_FALLBACK_MODEL

        try:
            self.sd_pipeline = StableDiffusionPipeline.from_pretrained(
                self.model_config["sd_model"],
                cache_dir=self.cache_dir,
                safety_checker=None,  # safety checker disabled to save memory
                torch_dtype=self._torch_dtype,
            ).to(self.device)

            if self.device == "cuda":
                self._enable_memory_optimizations(
                    self.sd_pipeline, "无法启用xformers,使用回退方案"
                )

            logger.info("Stable Diffusion模型加载完成")
            self.load_times["sd"] = time.time() - start_time
            self.last_used["sd"] = time.time()
        except Exception as e:
            logger.error(f"加载SD模型失败: {str(e)}")
            self.sd_pipeline = None
            # Retry only with a different checkpoint; see load_caption_model.
            if self.model_config["sd_model"] != self._SD_FALLBACK_MODEL:
                self.model_config["sd_model"] = self._SD_FALLBACK_MODEL
                self.load_sd_pipeline()

    def load_controlnet_pipeline(self) -> None:
        """Load the ControlNet pipeline used for 3D try-on images.

        The ControlNet is loaded first, then attached to a Stable Diffusion base
        taken from ``model_config["controlnet_base_model"]`` (SD 1.5 by
        default, matching the SD 1.5 openpose ControlNet). The scheduler is
        replaced with Euler-Ancestral. On failure, logs the error and leaves
        both ``controlnet`` and ``controlnet_pipeline`` as ``None``.
        """
        if self.controlnet_pipeline is not None:
            return
        start_time = time.time()
        logger.info("正在加载ControlNet模型...")
        try:
            # Load the ControlNet weights first ...
            self.controlnet = ControlNetModel.from_pretrained(
                self.model_config["controlnet_model"],
                cache_dir=self.cache_dir,
                torch_dtype=self._torch_dtype,
            )
            # ... then build the pipeline around them.
            base_model = self.model_config.get(
                "controlnet_base_model", self._SD_FALLBACK_MODEL
            )
            self.controlnet_pipeline = StableDiffusionControlNetPipeline.from_pretrained(
                base_model,
                controlnet=self.controlnet,
                cache_dir=self.cache_dir,
                safety_checker=None,
                torch_dtype=self._torch_dtype,
            ).to(self.device)

            # Swap in the Euler-Ancestral scheduler, keeping the pipeline's config.
            self.controlnet_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
                self.controlnet_pipeline.scheduler.config
            )

            if self.device == "cuda":
                self._enable_memory_optimizations(
                    self.controlnet_pipeline, "无法为ControlNet启用xformers"
                )

            logger.info("ControlNet模型加载完成")
            self.load_times["controlnet"] = time.time() - start_time
            self.last_used["controlnet"] = time.time()
        except Exception as e:
            logger.error(f"加载ControlNet模型失败: {str(e)}")
            # Drop any partially loaded pieces so the next call retries cleanly.
            self.controlnet = None
            self.controlnet_pipeline = None

    def generate_caption(self, image: Image.Image) -> str:
        """Return a BLIP caption for ``image`` (max 50 tokens).

        Returns the default caption "时尚服装设计" if the model cannot be
        loaded or generation fails.
        """
        try:
            self.load_caption_model()
            if self.caption_model is None:
                raise RuntimeError("caption model is not available")
            self.last_used["caption"] = time.time()

            # Prepare inputs on the model's device/dtype.
            inputs = self.caption_processor(
                images=image,
                return_tensors="pt",
            ).to(self.device, self._torch_dtype)

            output = self.caption_model.generate(**inputs, max_length=50)
            caption = self.caption_processor.decode(output[0], skip_special_tokens=True)

            logger.info(f"生成的标题: {caption}")
            return caption
        except Exception as e:
            logger.error(f"生成标题失败: {str(e)}")
            # Default caption.
            return "时尚服装设计"

    def analyze_style(self, image: Image.Image) -> Dict[str, float]:
        """Score ``image`` against fixed style labels with CLIP zero-shot matching.

        Returns the top-3 labels mapped to their softmax probability, in
        descending order. Returns a fixed default mapping on any failure.
        """
        try:
            self.load_clip_model()
            if self.clip_model is None:
                raise RuntimeError("CLIP model is not available")
            self.last_used["clip"] = time.time()

            # Candidate style labels.
            # NOTE(review): openai/clip-vit-base-patch32 was trained on English
            # captions; Chinese labels are likely scored poorly. Consider
            # English prompts mapped back to these keys - TODO confirm.
            style_labels = [
                "商务正装", "休闲风", "运动风", "时尚潮流",
                "复古风", "街头风", "优雅风", "民族风"
            ]

            inputs = self.clip_processor(
                text=style_labels,
                images=image,
                return_tensors="pt",
                padding=True,
            ).to(self.device)
            # Match pixel dtype to the weights (fp16 on CUDA).
            inputs["pixel_values"] = inputs["pixel_values"].to(self.clip_model.dtype)

            # Inference only: avoid building an autograd graph.
            with torch.no_grad():
                outputs = self.clip_model(**inputs)
            probs = outputs.logits_per_image.float().softmax(dim=1).cpu().numpy()[0]

            # Top 3 styles, highest probability first.
            top3_idx = np.argsort(probs)[-3:][::-1]
            top_styles = {style_labels[i]: float(probs[i]) for i in top3_idx}

            logger.info(f"风格分析结果: {top_styles}")
            return top_styles
        except Exception as e:
            logger.error(f"风格分析失败: {str(e)}")
            # Default styles.
            return {"休闲风": 0.8, "时尚潮流": 0.7}

    def generate_image(
        self,
        prompt: str,
        negative_prompt: str = "",
        num_inference_steps: int = 30,
        guidance_scale: float = 7.5,
        height: int = 512,
        width: int = 512
    ) -> Image.Image:
        """Generate a design image from ``prompt`` with Stable Diffusion.

        Returns a light-grey placeholder of the requested size (512x512 if the
        requested size is not positive) if loading or generation fails.
        """
        try:
            self.load_sd_pipeline()
            if self.sd_pipeline is None:
                raise RuntimeError("Stable Diffusion pipeline is not available")
            self.last_used["sd"] = time.time()

            # No torch.autocast: the weights are already in the target dtype and
            # diffusers discourages autocast (slower, can yield black images).
            result = self.sd_pipeline(
                prompt=prompt,
                negative_prompt=negative_prompt,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                height=height,
                width=width
            ).images[0]

            logger.info(f"成功生成设计图像: {prompt[:50]}...")
            return result
        except Exception as e:
            logger.error(f"生成设计图像失败: {str(e)}")
            # Placeholder image matching the requested size.
            size = (width, height) if width > 0 and height > 0 else (512, 512)
            return Image.new('RGB', size, color=(220, 220, 220))

    def generate_controlnet_image(
        self,
        image: Image.Image,
        prompt: str,
        negative_prompt: str = "",
        num_inference_steps: int = 35,
        guidance_scale: float = 8.0,
        controlnet_conditioning_scale: float = 0.8
    ) -> Image.Image:
        """Generate a 3D try-on image conditioned on ``image`` via ControlNet.

        ``image`` is passed unchanged as the ControlNet conditioning image.
        NOTE(review): the openpose ControlNet expects a pose-skeleton map, so
        callers presumably pass one - confirm.

        On failure, falls back to :meth:`generate_image` with the same prompt,
        negative prompt, step count and guidance scale.
        """
        try:
            self.load_controlnet_pipeline()
            if self.controlnet_pipeline is None:
                raise RuntimeError("ControlNet pipeline is not available")
            self.last_used["controlnet"] = time.time()

            result = self.controlnet_pipeline(
                prompt=prompt,
                image=image,
                negative_prompt=negative_prompt,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                controlnet_conditioning_scale=controlnet_conditioning_scale
            ).images[0]

            logger.info("成功生成3D试穿图像")
            return result
        except Exception as e:
            logger.error(f"生成3D试穿图像失败: {str(e)}")
            # Fall back to the plain SD pipeline.
            return self.generate_image(
                prompt,
                negative_prompt,
                num_inference_steps,
                guidance_scale
            )

    def unload_model(self, model_type: str):
        """Release one model ("caption", "clip", "sd" or "controlnet").

        Unknown or already-unloaded types are ignored. Memory cleanup
        (:meth:`cleanup_memory`) runs in every case.
        """
        logger.info(f"卸载模型: {model_type}")

        # Rebinding to None drops this object's reference to the model.
        if model_type == "caption" and self.caption_model is not None:
            self.caption_model = None
            self.caption_processor = None
            logger.info("卸载图像描述模型")
        elif model_type == "clip" and self.clip_model is not None:
            self.clip_model = None
            self.clip_processor = None
            logger.info("卸载CLIP模型")
        elif model_type == "sd" and self.sd_pipeline is not None:
            self.sd_pipeline = None
            logger.info("卸载Stable Diffusion模型")
        elif model_type == "controlnet" and self.controlnet_pipeline is not None:
            self.controlnet_pipeline = None
            self.controlnet = None
            logger.info("卸载ControlNet模型")

        self.cleanup_memory()

    def cleanup(self):
        """Drop references to every model/processor/pipeline, then free memory."""
        logger.info("清理所有模型释放内存...")

        self.caption_model = None
        self.caption_processor = None
        self.clip_model = None
        self.clip_processor = None
        self.sd_pipeline = None
        self.controlnet_pipeline = None
        self.controlnet = None

        self.cleanup_memory()
        logger.info("内存清理完成")

    def cleanup_memory(self):
        """Run garbage collection and release cached CUDA memory if CUDA exists."""
        # Empty the CUDA caching allocator.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        # Run the garbage collector.
        gc.collect()

    def get_memory_usage(self) -> Dict[str, float]:
        """Return GPU memory figures in GiB (empty dict when CUDA is unavailable).

        ``gpu_used`` is ``torch.cuda.memory_allocated()`` (memory occupied by
        tensors, as reported by PyTorch), so ``gpu_free`` is total minus that
        figure, not the driver-reported free memory.
        """
        mem_info = {}
        if torch.cuda.is_available():
            mem_info["gpu_total"] = torch.cuda.get_device_properties(0).total_memory / (1024**3)
            mem_info["gpu_used"] = torch.cuda.memory_allocated() / (1024**3)
            mem_info["gpu_free"] = mem_info["gpu_total"] - mem_info["gpu_used"]
        return mem_info

    def get_model_status(self) -> Dict[str, str]:
        """Return load state, load time and minutes-since-last-use per model."""
        status = {
            "caption_model": "已加载" if self.caption_model is not None else "未加载",
            "clip_model": "已加载" if self.clip_model is not None else "未加载",
            "sd_model": "已加载" if self.sd_pipeline is not None else "未加载",
            "controlnet_model": "已加载" if self.controlnet_pipeline is not None else "未加载"
        }

        # Add load-time and last-use information where recorded.
        for model in ["caption", "clip", "sd", "controlnet"]:
            if model in self.load_times:
                status[f"{model}_load_time"] = f"{self.load_times[model]:.2f}秒"
            if model in self.last_used:
                mins_ago = (time.time() - self.last_used[model]) / 60
                status[f"{model}_last_used"] = f"{mins_ago:.1f}分钟前"

        return status

    def __del__(self):
        """Best-effort release of all models when the manager is collected."""
        try:
            self.cleanup()
        except Exception:
            # During interpreter shutdown module globals (torch, gc, logger)
            # may already be torn down, and a failed __init__ leaves attributes
            # missing; a destructor must not propagate either case.
            pass