import torch
from PIL import Image
import numpy as np
from transformers import BlipProcessor, BlipForConditionalGeneration, CLIPProcessor, CLIPModel
from diffusers import StableDiffusionPipeline, ControlNetModel, StableDiffusionControlNetPipeline, EulerAncestralDiscreteScheduler
import os
import logging
import time
import random

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ModelManager:
    """Load and serve the vision/generation models used by the application.

    Wraps four model stacks behind lazy-reload accessors:

    * BLIP  -> image captioning (``generate_caption``)
    * CLIP  -> zero-shot clothing-style scoring (``analyze_style``)
    * SD1.5 -> text-to-image design generation (``generate_image``)
    * SD1.5 + OpenPose ControlNet -> pose-guided try-on renders
      (``generate_controlnet_image``)

    Models run in float16 on CUDA and float32 on CPU. A failed load is
    logged and leaves the corresponding attribute as ``None``; the
    ``generate_*`` methods then retry the load and, if it still fails,
    degrade gracefully (placeholder image / empty caption / uniform
    style scores) instead of crashing.
    """

    def __init__(self):
        # Prefer GPU when available; fp16 on CUDA halves VRAM usage.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Single source of truth for the model dtype (the original
        # repeated this conditional in every loader).
        self.dtype = torch.float16 if self.device == "cuda" else torch.float32
        logger.info(f"使用设备: {self.device}")

        # Model configuration - larger checkpoints for finer 3D detail.
        self.model_config = {
            "caption_model": "Salesforce/blip-image-captioning-large",
            "clip_model": "openai/clip-vit-large-patch14",
            "sd_model": "runwayml/stable-diffusion-v1-5",
            "controlnet_model": "lllyasviel/control_v11p_sd15_openpose",
        }

        # Model containers; None means "not loaded (yet)".
        self.caption_processor = None
        self.caption_model = None
        self.clip_processor = None
        self.clip_model = None
        self.sd_pipeline = None
        self.controlnet = None
        self.controlnet_pipeline = None

        # Eagerly preload everything so first requests are fast.
        self.load_all_models()

    def load_all_models(self):
        """Preload every model stack; each loader logs its own failures."""
        self.load_caption_model()
        self.load_clip_model()
        self.load_sd_pipeline()
        self.load_controlnet_pipeline()

    def load_caption_model(self):
        """Load the BLIP captioning processor and model onto ``self.device``."""
        try:
            logger.info("加载 BLIP 图像描述模型...")
            self.caption_processor = BlipProcessor.from_pretrained(
                self.model_config["caption_model"],
                cache_dir="/tmp/models",
            )
            self.caption_model = BlipForConditionalGeneration.from_pretrained(
                self.model_config["caption_model"],
                cache_dir="/tmp/models",
                torch_dtype=self.dtype,
            ).to(self.device)
            logger.info("BLIP 模型加载完成")
        except Exception as e:
            # Best-effort: leave attributes as None; callers re-check.
            logger.error(f"BLIP 模型加载失败: {e}")

    def load_clip_model(self):
        """Load the CLIP processor and model onto ``self.device``."""
        try:
            logger.info("加载 CLIP 模型...")
            self.clip_processor = CLIPProcessor.from_pretrained(
                self.model_config["clip_model"],
                cache_dir="/tmp/models",
            )
            self.clip_model = CLIPModel.from_pretrained(
                self.model_config["clip_model"],
                cache_dir="/tmp/models",
                torch_dtype=self.dtype,
            ).to(self.device)
            logger.info("CLIP 模型加载完成")
        except Exception as e:
            logger.error(f"CLIP 模型加载失败: {e}")

    def load_sd_pipeline(self):
        """Load the Stable Diffusion text-to-image pipeline."""
        try:
            logger.info("加载 Stable Diffusion Pipeline...")
            self.sd_pipeline = StableDiffusionPipeline.from_pretrained(
                self.model_config["sd_model"],
                torch_dtype=self.dtype,
                cache_dir="/tmp/models",
                safety_checker=None,  # deliberately disabled for this app
                use_safetensors=True,
            )
            self.sd_pipeline = self.sd_pipeline.to(self.device)
            # Euler-ancestral converges well at low step counts.
            self.sd_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
                self.sd_pipeline.scheduler.config
            )
            logger.info("Stable Diffusion Pipeline 加载完成")
        except Exception as e:
            logger.error(f"Stable Diffusion Pipeline 加载失败: {e}")

    def load_controlnet_pipeline(self):
        """Load the OpenPose ControlNet and its SD pipeline."""
        try:
            logger.info("加载 ControlNet 模型和 Pipeline...")
            self.controlnet = ControlNetModel.from_pretrained(
                self.model_config["controlnet_model"],
                cache_dir="/tmp/models",
                torch_dtype=self.dtype,
            ).to(self.device)
            self.controlnet_pipeline = StableDiffusionControlNetPipeline.from_pretrained(
                self.model_config["sd_model"],
                controlnet=self.controlnet,
                cache_dir="/tmp/models",
                torch_dtype=self.dtype,
                safety_checker=None,
            ).to(self.device)
            self.controlnet_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
                self.controlnet_pipeline.scheduler.config
            )
            logger.info("ControlNet Pipeline 加载完成")
        except Exception as e:
            logger.error(f"ControlNet Pipeline 加载失败: {e}")

    # ------------------------------------------------------------------
    # Inference entry points
    # ------------------------------------------------------------------

    def _cast_inputs(self, inputs):
        """Move processor outputs to the model dtype.

        Fix: on CUDA the models are fp16 but processors emit fp32
        ``pixel_values``; feeding them unconverted raises a dtype
        mismatch. Integer tensors (``input_ids`` etc.) are left alone.
        """
        return {
            k: v.to(self.dtype) if torch.is_floating_point(v) else v
            for k, v in inputs.items()
        }

    def generate_caption(self, image):
        """Generate an image description with BLIP.

        Args:
            image: a PIL image (anything ``BlipProcessor`` accepts).

        Returns:
            str: decoded caption; ``""`` if the model is unavailable.
        """
        if self.caption_model is None or self.caption_processor is None:
            self.load_caption_model()
        # Fix: the original assumed the lazy reload always succeeded and
        # would crash with a TypeError on a still-None model.
        if self.caption_model is None or self.caption_processor is None:
            logger.error("无法生成描述:BLIP 模型未加载")
            return ""
        inputs = self.caption_processor(images=image, return_tensors="pt").to(self.device)
        inputs = self._cast_inputs(inputs)
        with torch.no_grad():
            outputs = self.caption_model.generate(**inputs, max_length=50)
        caption = self.caption_processor.decode(outputs[0], skip_special_tokens=True)
        return caption

    def analyze_style(self, image):
        """Score clothing style categories for *image* with CLIP.

        Args:
            image: a PIL image.

        Returns:
            dict[str, float]: style label -> softmax probability. Falls
            back to a uniform distribution if the model is unavailable.
        """
        styles = ["商务正装", "休闲风", "运动风", "时尚潮流", "复古风", "街头风", "优雅风"]
        if self.clip_model is None or self.clip_processor is None:
            self.load_clip_model()
        # Fix: guard the failure path instead of crashing (see generate_caption).
        if self.clip_model is None or self.clip_processor is None:
            logger.error("无法分析风格:CLIP 模型未加载")
            return {style: 1.0 / len(styles) for style in styles}
        inputs = self.clip_processor(
            text=styles,
            images=image,
            return_tensors="pt",
            padding=True,
        ).to(self.device)
        inputs = self._cast_inputs(inputs)
        with torch.no_grad():
            outputs = self.clip_model(**inputs)
            logits_per_image = outputs.logits_per_image
            probs = logits_per_image.softmax(dim=1).cpu().numpy()[0]
        style_scores = {style: float(prob) for style, prob in zip(styles, probs)}
        return style_scores

    def generate_image(self, prompt, negative_prompt=None, num_inference_steps=25,
                       guidance_scale=7.5, width=512, height=512):
        """Generate a design image from a text prompt with Stable Diffusion.

        Args:
            prompt: positive text prompt.
            negative_prompt: optional negative prompt.
            num_inference_steps: diffusion steps (quality vs. speed).
            guidance_scale: classifier-free guidance strength.
            width, height: output size in pixels.

        Returns:
            PIL.Image: generated image, or a placeholder on failure.
        """
        if self.sd_pipeline is None:
            self.load_sd_pipeline()
        if self.sd_pipeline is None:
            logger.error("无法生成图像:Stable Diffusion 模型未加载")
            return self.create_placeholder_image(width, height)
        result = self.sd_pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            height=height,
            width=width,
        )
        return result.images[0]

    def generate_controlnet_image(self, image, prompt, reference_image=None,
                                  negative_prompt=None, num_inference_steps=30,
                                  guidance_scale=8.0):
        """Generate a pose-conditioned 3D try-on render with ControlNet.

        Args:
            image: conditioning image (expected to be an OpenPose map for
                this ControlNet checkpoint — TODO confirm with callers).
            prompt: positive text prompt.
            reference_image: optional; when present the prompt is tagged
                so generation leans on the reference design.
            negative_prompt: optional negative prompt.
            num_inference_steps: diffusion steps.
            guidance_scale: classifier-free guidance strength.

        Returns:
            PIL.Image: generated image, or a 512x768 placeholder on failure.
        """
        if self.controlnet_pipeline is None:
            self.load_controlnet_pipeline()
        if self.controlnet_pipeline is None:
            logger.error("无法生成3D试穿:ControlNet 模型未加载")
            return self.create_placeholder_image(512, 768)
        if reference_image is not None:
            prompt = f"{prompt}, based on reference design"
        result = self.controlnet_pipeline(
            prompt=prompt,
            image=image,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        )
        return result.images[0]

    def create_placeholder_image(self, width, height):
        """Return a solid-color RGB placeholder of the requested size."""
        color = (
            random.randint(120, 200),
            random.randint(120, 200),
            random.randint(120, 200),
        )
        return Image.new('RGB', (width, height), color=color)

    def cleanup(self):
        """Release model memory and the CUDA cache.

        Fix: the original used ``del self.attr``, which removed the
        attributes entirely — any later ``generate_*`` call then raised
        AttributeError instead of lazily reloading, and a second
        ``cleanup()`` failed outright. Resetting to ``None`` drops the
        references (allowing GC) while keeping the lazy-reload contract.
        """
        logger.info("释放模型占用显存和缓存...")
        try:
            self.caption_model = None
            self.clip_model = None
            self.sd_pipeline = None
            self.controlnet = None
            self.controlnet_pipeline = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            logger.info("显存和缓存清理完成")
        except Exception as e:
            logger.error(f"清理显存失败: {e}")