# models/model_manager.py
import gc
import logging
import os
import time

import numpy as np
import torch
from PIL import Image
from transformers import (
    BlipForConditionalGeneration,
    BlipProcessor,
    CLIPModel,
    CLIPProcessor,
)
from diffusers import (
    ControlNetModel,
    EulerAncestralDiscreteScheduler,
    StableDiffusionControlNetPipeline,
    StableDiffusionPipeline,
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Download/cache directory shared by every pretrained checkpoint loaded below.
MODEL_CACHE_DIR = "/tmp/models"


class ModelManager:
    """Load and serve BLIP captioning, CLIP image features, Stable Diffusion
    text-to-image and ControlNet (openpose) generation on a single device.

    Every model is loaded eagerly in ``__init__``. A loader that fails logs
    the error (with traceback) and leaves its attributes as ``None``; each
    inference method retries the corresponding load once and raises
    ``RuntimeError`` if the model is still unavailable.
    """

    # Attributes holding loaded models; each is None while not loaded.
    _MODEL_ATTRS = (
        "caption_processor",
        "caption_model",
        "clip_processor",
        "clip_model",
        "sd_pipeline",
        "controlnet",
        "controlnet_pipeline",
    )

    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"使用设备: {self.device}")
        # Half precision on GPU to save memory; full precision on CPU.
        self.torch_dtype = torch.float16 if self.device == "cuda" else torch.float32

        # Model configuration: Hugging Face Hub model ids.
        self.model_config = {
            "caption_model": "Salesforce/blip-image-captioning-base",
            "clip_model": "openai/clip-vit-base-patch32",
            # Original upstream repo; can be replaced with a mirror.
            # NOTE(review): the runwayml repo appears to have been removed from
            # the Hub; "stable-diffusion-v1-5/stable-diffusion-v1-5" may be
            # needed instead — confirm.
            "sd_model": "runwayml/stable-diffusion-v1-5",
            "controlnet_model": "lllyasviel/sd-controlnet-openpose",
        }

        # Model containers; None means "not loaded".
        self.caption_processor = None
        self.caption_model = None
        self.clip_processor = None
        self.clip_model = None
        self.sd_pipeline = None
        self.controlnet = None
        self.controlnet_pipeline = None

        # Eagerly load every model.
        self.load_all_models()

    def load_all_models(self):
        """Load all models. SD is loaded before ControlNet so the ControlNet
        pipeline can reuse the SD base weights instead of loading a copy."""
        self.load_caption_model()
        self.load_clip_model()
        self.load_sd_pipeline()
        self.load_controlnet_pipeline()

    def load_caption_model(self):
        """Load the BLIP captioning processor and model; log and continue on failure."""
        try:
            logger.info("加载 BLIP 图像描述模型...")
            self.caption_processor = BlipProcessor.from_pretrained(
                self.model_config["caption_model"], cache_dir=MODEL_CACHE_DIR
            )
            self.caption_model = BlipForConditionalGeneration.from_pretrained(
                self.model_config["caption_model"],
                cache_dir=MODEL_CACHE_DIR,
                torch_dtype=self.torch_dtype,
            ).to(self.device)
            logger.info("BLIP 模型加载完成")
        except Exception as e:
            logger.exception(f"BLIP 模型加载失败: {e}")

    def load_clip_model(self):
        """Load the CLIP processor and model; log and continue on failure."""
        try:
            logger.info("加载 CLIP 模型...")
            self.clip_processor = CLIPProcessor.from_pretrained(
                self.model_config["clip_model"], cache_dir=MODEL_CACHE_DIR
            )
            self.clip_model = CLIPModel.from_pretrained(
                self.model_config["clip_model"],
                cache_dir=MODEL_CACHE_DIR,
                torch_dtype=self.torch_dtype,
            ).to(self.device)
            logger.info("CLIP 模型加载完成")
        except Exception as e:
            logger.exception(f"CLIP 模型加载失败: {e}")

    def load_sd_pipeline(self):
        """Load the Stable Diffusion text-to-image pipeline (safety checker
        disabled, Euler-ancestral scheduler); log and continue on failure."""
        try:
            logger.info("加载 Stable Diffusion Pipeline...")
            # Load the main-branch weights (same files the ControlNet fallback
            # uses); the legacy revision="fp16" branch is deprecated in
            # diffusers, and torch_dtype performs the fp16 cast on GPU.
            pipeline = StableDiffusionPipeline.from_pretrained(
                self.model_config["sd_model"],
                torch_dtype=self.torch_dtype,
                cache_dir=MODEL_CACHE_DIR,
                safety_checker=None,  # configure a safety checker here if needed
            ).to(self.device)
            # Swap in a more efficient scheduler, keeping the model's scheduler config.
            pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
                pipeline.scheduler.config
            )
            self.sd_pipeline = pipeline
            logger.info("Stable Diffusion Pipeline 加载完成")
        except Exception as e:
            logger.exception(f"Stable Diffusion Pipeline 加载失败: {e}")

    def load_controlnet_pipeline(self):
        """Load the openpose ControlNet and a ControlNet pipeline around it.

        If the SD pipeline is already loaded, its VAE, text encoder, tokenizer,
        UNet and feature extractor are reused so the base weights live on the
        device only once; otherwise the base model is loaded from the Hub.
        Logs and continues on failure.
        """
        try:
            logger.info("加载 ControlNet 模型和 Pipeline...")
            self.controlnet = ControlNetModel.from_pretrained(
                self.model_config["controlnet_model"],
                cache_dir=MODEL_CACHE_DIR,
                torch_dtype=self.torch_dtype,
            ).to(self.device)

            base = self.sd_pipeline
            if base is not None:
                pipeline = StableDiffusionControlNetPipeline(
                    vae=base.vae,
                    text_encoder=base.text_encoder,
                    tokenizer=base.tokenizer,
                    unet=base.unet,
                    controlnet=self.controlnet,
                    scheduler=base.scheduler,
                    safety_checker=None,
                    feature_extractor=base.feature_extractor,
                )
            else:
                pipeline = StableDiffusionControlNetPipeline.from_pretrained(
                    self.model_config["sd_model"],
                    controlnet=self.controlnet,
                    cache_dir=MODEL_CACHE_DIR,
                    torch_dtype=self.torch_dtype,
                    safety_checker=None,
                ).to(self.device)
            # Give this pipeline its own scheduler instance so it never shares
            # mutable scheduler state with sd_pipeline.
            pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
                pipeline.scheduler.config
            )
            self.controlnet_pipeline = pipeline
            logger.info("ControlNet Pipeline 加载完成")
        except Exception as e:
            logger.exception(f"ControlNet Pipeline 加载失败: {e}")

    def _ensure_loaded(self, attr_names, loader, label):
        """Call ``loader`` if any of ``attr_names`` is unset, then raise
        ``RuntimeError`` if any is still unset (the loaders swallow errors)."""
        if any(getattr(self, name, None) is None for name in attr_names):
            loader()
        if any(getattr(self, name, None) is None for name in attr_names):
            raise RuntimeError(f"{label} is not available; see the load error logged above")

    # Public inference interfaces below.

    def generate_caption(self, image, max_length=50):
        """Return a BLIP-generated caption for ``image``.

        Args:
            image: input image accepted by ``BlipProcessor`` (e.g. a PIL Image).
            max_length: maximum generated sequence length.

        Raises:
            RuntimeError: if the BLIP model cannot be loaded.
        """
        self._ensure_loaded(
            ("caption_processor", "caption_model"), self.load_caption_model, "BLIP caption model"
        )
        inputs = self.caption_processor(images=image, return_tensors="pt")
        # The processor emits float32 pixels; cast to the model dtype (fp16 on
        # CUDA) or the forward pass fails with a dtype mismatch.
        pixel_values = inputs["pixel_values"].to(self.device, dtype=self.caption_model.dtype)
        with torch.no_grad():
            outputs = self.caption_model.generate(pixel_values=pixel_values, max_length=max_length)
        return self.caption_processor.decode(outputs[0], skip_special_tokens=True)

    def analyze_style(self, image):
        """Return ``{"clip_feature_vector": v}``, where ``v`` is the L2-normalized
        CLIP image embedding of the first image, as a float32 numpy array.

        Raises:
            RuntimeError: if the CLIP model cannot be loaded.
        """
        self._ensure_loaded(("clip_processor", "clip_model"), self.load_clip_model, "CLIP model")
        inputs = self.clip_processor(images=image, return_tensors="pt")
        pixel_values = inputs["pixel_values"].to(self.device, dtype=self.clip_model.dtype)
        with torch.no_grad():
            outputs = self.clip_model.get_image_features(pixel_values=pixel_values)
        # Upcast before leaving the device so the norm is not computed in fp16.
        features = outputs[0].float().cpu().numpy()
        # Simple L2 normalization; epsilon guards against an all-zero vector.
        norm = features / (np.linalg.norm(features) + 1e-10)
        return {"clip_feature_vector": norm}

    def generate_image(self, prompt, negative_prompt=None, num_inference_steps=25,
                       guidance_scale=7.5, width=512, height=512):
        """Generate one image from ``prompt`` with Stable Diffusion and return it.

        Raises:
            RuntimeError: if the SD pipeline cannot be loaded.
        """
        self._ensure_loaded(("sd_pipeline",), self.load_sd_pipeline, "Stable Diffusion pipeline")
        result = self.sd_pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            height=height,
            width=width,
        )
        return result.images[0]

    def generate_controlnet_image(self, image, prompt, negative_prompt=None,
                                  num_inference_steps=30, guidance_scale=8.0):
        """Generate one image conditioned on a control image and return it.

        Args:
            image: control image as a PIL Image (e.g. a human-pose map).

        Raises:
            RuntimeError: if the ControlNet pipeline cannot be loaded.
        """
        self._ensure_loaded(
            ("controlnet_pipeline",), self.load_controlnet_pipeline, "ControlNet pipeline"
        )
        result = self.controlnet_pipeline(
            prompt=prompt,
            image=image,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        )
        return result.images[0]

    def cleanup(self):
        """Drop every model reference and release cached GPU memory.

        Attributes are reset to None (not deleted), so the lazy reload in the
        inference methods keeps working and calling cleanup twice is safe.
        """
        logger.info("释放模型占用显存和缓存...")
        try:
            for name in self._MODEL_ATTRS:
                setattr(self, name, None)
            # Collect first so tensors kept alive only by reference cycles are
            # freed before the CUDA caching allocator is asked to release memory.
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            logger.info("显存清理完成")
        except Exception as e:
            logger.exception(f"清理显存失败: {e}")