# modal_manager.py — complete implementation of the former model_manager.py,
# relocated to model/modal_manager.py as a drop-in replacement.
# Loads and serves the models used by the app: BLIP image captioning, CLIP
# style classification, Stable Diffusion 1.5 text-to-image, and an OpenPose
# ControlNet pipeline (whose prompt can carry a viewing angle for multi-angle
# try-on renders), with GPU memory optimizations (fp16 on CUDA, attention and
# VAE slicing, optional xFormers, SD weights shared between both pipelines).
import torch
from PIL import Image
import numpy as np
from transformers import BlipProcessor, BlipForConditionalGeneration, CLIPProcessor, CLIPModel
from diffusers import StableDiffusionPipeline, ControlNetModel, StableDiffusionControlNetPipeline, EulerAncestralDiscreteScheduler
import os
import logging
import time
import random
import gc

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Download cache shared by every from_pretrained call in this module.
_CACHE_DIR = "/tmp/models"


class ModelManager:
    """Own the caption, CLIP, Stable Diffusion and ControlNet models for one device.

    Models are loaded sequentially from ``__init__``. A loading failure is
    logged (not raised); the failing model and every model after it stay
    ``None``, so check ``get_model_status()`` before calling a generator.
    """

    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info("使用设备: %s", self.device)

        self.model_config = {
            "caption_model": "Salesforce/blip-image-captioning-large",
            "clip_model": "openai/clip-vit-large-patch14",
            "sd_model": "runwayml/stable-diffusion-v1-5",
            "controlnet_model": "lllyasviel/control_v11p_sd15_openpose",
        }

        self.caption_processor = None
        self.caption_model = None
        self.clip_processor = None
        self.clip_model = None
        self.sd_pipeline = None
        self.controlnet = None
        self.controlnet_pipeline = None

        # fp16 halves GPU memory; CPU inference needs fp32.
        self.torch_dtype = torch.float16 if self.device == "cuda" else torch.float32
        self.enable_attention_slicing = True
        # NOTE(review): not read anywhere in this class; kept for backward compatibility.
        self.enable_cpu_offload = False

        try:
            self.load_all_models()
        except Exception as e:
            # Best effort: the manager stays usable with whatever loaded before the failure.
            logger.warning("加载模型时出错: %s", e, exc_info=True)

    # ------------------------------------------------------------------ loading

    def optimize_memory_usage(self):
        """Enable cuDNN autotuning and TF32 matmul/conv kernels on CUDA; no-op on CPU."""
        if torch.cuda.is_available():
            torch.backends.cudnn.benchmark = True
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True

    def load_all_models(self):
        """Load every model in order; the first exception propagates to the caller."""
        self.optimize_memory_usage()
        self.load_caption_model()
        self.load_clip_model()
        self.load_sd_pipeline()
        # Must follow load_sd_pipeline so the SD components can be shared.
        self.load_controlnet_pipeline()
        logger.info("所有模型加载完成")

    def load_caption_model(self):
        """Load the BLIP captioning processor and model onto ``self.device``."""
        self.caption_processor = BlipProcessor.from_pretrained(
            self.model_config["caption_model"], cache_dir=_CACHE_DIR
        )
        # NOTE: transformers models have no enable_attention_slicing(); that is a
        # diffusers-pipeline API and calling it here raised AttributeError.
        self.caption_model = BlipForConditionalGeneration.from_pretrained(
            self.model_config["caption_model"],
            cache_dir=_CACHE_DIR,
            torch_dtype=self.torch_dtype,
            low_cpu_mem_usage=True,
        ).to(self.device)
        self.caption_model.eval()

    def load_clip_model(self):
        """Load the CLIP processor and model used for zero-shot style scoring."""
        self.clip_processor = CLIPProcessor.from_pretrained(
            self.model_config["clip_model"], cache_dir=_CACHE_DIR
        )
        self.clip_model = CLIPModel.from_pretrained(
            self.model_config["clip_model"],
            cache_dir=_CACHE_DIR,
            torch_dtype=self.torch_dtype,
        ).to(self.device)
        self.clip_model.eval()

    def load_sd_pipeline(self):
        """Load the SD 1.5 text-to-image pipeline with an Euler-Ancestral scheduler."""
        pipe = StableDiffusionPipeline.from_pretrained(
            self.model_config["sd_model"],
            torch_dtype=self.torch_dtype,
            cache_dir=_CACHE_DIR,
            safety_checker=None,
            requires_safety_checker=False,
            use_safetensors=True,
            low_cpu_mem_usage=True,
        ).to(self.device)
        pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
        self._apply_memory_optimizations(pipe)
        self.sd_pipeline = pipe

    def load_controlnet_pipeline(self):
        """Load the OpenPose ControlNet and build ``self.controlnet_pipeline``.

        When ``self.sd_pipeline`` is loaded, its VAE, text encoder, tokenizer,
        UNet and feature extractor are reused, so the SD 1.5 weights are held in
        memory once. Otherwise the base model is loaded from
        ``model_config["sd_model"]``.
        """
        self.controlnet = ControlNetModel.from_pretrained(
            self.model_config["controlnet_model"],
            cache_dir=_CACHE_DIR,
            torch_dtype=self.torch_dtype,
            low_cpu_mem_usage=True,
        ).to(self.device)

        base = self.sd_pipeline
        if base is not None:
            pipe = StableDiffusionControlNetPipeline(
                vae=base.vae,
                text_encoder=base.text_encoder,
                tokenizer=base.tokenizer,
                unet=base.unet,
                controlnet=self.controlnet,
                # Separate scheduler instance so the pipelines never share scheduler state.
                scheduler=EulerAncestralDiscreteScheduler.from_config(base.scheduler.config),
                safety_checker=None,
                feature_extractor=base.feature_extractor,
                requires_safety_checker=False,
            )
        else:
            pipe = StableDiffusionControlNetPipeline.from_pretrained(
                self.model_config["sd_model"],
                controlnet=self.controlnet,
                cache_dir=_CACHE_DIR,
                torch_dtype=self.torch_dtype,
                safety_checker=None,
                requires_safety_checker=False,
                low_cpu_mem_usage=True,
            )
            pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)

        self.controlnet_pipeline = pipe.to(self.device)
        self._apply_memory_optimizations(self.controlnet_pipeline)

    def _apply_memory_optimizations(self, pipe):
        """Enable attention slicing (if configured), xFormers (if usable) and VAE slicing on *pipe*."""
        if self.enable_attention_slicing:
            pipe.enable_attention_slicing()
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except Exception as exc:
            # xFormers is optional; the default attention implementation is used instead.
            logger.debug("xFormers unavailable, using default attention: %s", exc)
        pipe.enable_vae_slicing()

    # ----------------------------------------------------------------- helpers

    @staticmethod
    def _prepare_image(image, max_side):
        """Return an RGB version of *image* no larger than *max_side* on either side.

        ``Image.thumbnail`` resizes in place, so the caller's image is copied
        first and is never modified.
        """
        rgb = image if image.mode == "RGB" else image.convert("RGB")
        if rgb.width > max_side or rgb.height > max_side:
            if rgb is image:
                rgb = rgb.copy()
            rgb.thumbnail((max_side, max_side), Image.Resampling.LANCZOS)
        return rgb

    def _to_model_inputs(self, batch):
        """Move processor outputs to ``self.device``.

        Floating-point tensors (e.g. ``pixel_values``) are cast to
        ``self.torch_dtype`` so they match fp16 model weights on CUDA. Integer
        tensors such as ``input_ids`` keep their dtype.
        """
        moved = {}
        for name, value in batch.items():
            if isinstance(value, torch.Tensor):
                if torch.is_floating_point(value):
                    value = value.to(self.device, dtype=self.torch_dtype)
                else:
                    value = value.to(self.device)
            moved[name] = value
        return moved

    # -------------------------------------------------------------- inference

    @torch.no_grad()
    def generate_caption(self, image):
        """Return a BLIP caption for *image*.

        The image is downscaled to at most 512 px per side before encoding.
        Sampling is enabled (``do_sample=True``), so captions can vary between calls.
        """
        prepared = self._prepare_image(image, 512)
        inputs = self._to_model_inputs(
            self.caption_processor(images=prepared, return_tensors="pt")
        )
        outputs = self.caption_model.generate(
            **inputs, max_length=50, num_beams=4, temperature=0.7, do_sample=True
        )
        caption = self.caption_processor.decode(outputs[0], skip_special_tokens=True)
        del inputs, outputs
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return caption

    @torch.no_grad()
    def analyze_style(self, image):
        """Score *image* against seven clothing styles with zero-shot CLIP.

        Returns:
            dict mapping the Chinese style name to its softmax probability; the
            values sum to 1.
        """
        style_labels = [
            "business formal suit professional attire",
            "casual comfortable everyday wear",
            "athletic sportswear activewear",
            "fashion trendy modern stylish",
            "vintage retro classic style",
            "streetwear urban contemporary",
            "elegant sophisticated refined",
        ]
        # Display names, index-aligned with style_labels.
        style_names = ["商务正装", "休闲风", "运动风", "时尚潮流", "复古风", "街头风", "优雅风"]

        prepared = self._prepare_image(image, 224)
        inputs = self._to_model_inputs(
            self.clip_processor(
                text=style_labels,
                images=prepared,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=77,
            )
        )
        outputs = self.clip_model(**inputs)
        # float() before softmax keeps fp16 logits from losing precision.
        probs = outputs.logits_per_image.float().softmax(dim=1).cpu().numpy()[0]
        return {name: float(prob) for name, prob in zip(style_names, probs)}

    @torch.no_grad()
    def generate_image(self, prompt, negative_prompt=None, num_inference_steps=25,
                       guidance_scale=7.5, width=512, height=512, seed=None):
        """Generate one image from *prompt* with the SD 1.5 pipeline.

        *width* and *height* are rounded down to multiples of 8, as the
        pipeline requires. A non-None *seed* makes the result reproducible.
        """
        if negative_prompt is None:
            negative_prompt = "blurry, low quality, distorted, text, watermark, ugly, deformed"
        width = (width // 8) * 8
        height = (height // 8) * 8
        gen = torch.Generator(device=self.device).manual_seed(int(seed)) if seed is not None else None
        result = self.sd_pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            height=height,
            width=width,
            generator=gen,
        )
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return result.images[0]

    @torch.no_grad()
    def generate_controlnet_image(self, image, prompt, reference_image=None, negative_prompt=None,
                                  num_inference_steps=30, guidance_scale=8.0, angle=0,
                                  width=512, height=768, seed=None):
        """Generate an image conditioned on *image* through the ControlNet pipeline.

        Args:
            image: conditioning image (presumably an OpenPose map, given the
                configured ControlNet — confirm with callers); it is resized to
                *width* x *height*.
            prompt: text prompt; ``", view from {angle} degrees"`` is appended.
            reference_image: not fed to the model; when given, only a phrase
                mentioning a reference design is appended to the prompt.
            width, height: output size, rounded down to multiples of 8.
            seed: RNG seed. When None, ``int(time.time()) + int(angle)`` is used,
                so repeated calls differ.
        """
        if image.mode != "RGB":
            image = image.convert("RGB")
        width = (int(width) // 8) * 8
        height = (int(height) // 8) * 8
        # resize() returns a new image, so the caller's image is untouched.
        control_image = image.resize((width, height), Image.Resampling.LANCZOS)

        if negative_prompt is None:
            negative_prompt = "blurry, distorted, low quality, unrealistic, extra limbs, deformed, bad anatomy, multiple people"
        prompt_with_angle = f"{prompt}, view from {angle} degrees"
        if reference_image is not None:
            prompt_with_angle = f"{prompt_with_angle}, based on provided reference design"

        if seed is None:
            seed = int(time.time()) + int(angle)
        gen = torch.Generator(device=self.device).manual_seed(int(seed))

        result = self.controlnet_pipeline(
            prompt=prompt_with_angle,
            image=control_image,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            controlnet_conditioning_scale=1.0,
            height=height,
            width=width,
            generator=gen,
        )
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return result.images[0]

    # ------------------------------------------------------------- utilities

    def create_placeholder_image(self, width, height):
        """Return a solid RGB image of the given size in a randomly chosen pastel color."""
        color = random.choice([(220, 220, 220), (200, 220, 240), (240, 220, 200), (220, 240, 200)])
        return Image.new("RGB", (width, height), color=color)

    def cleanup(self):
        """Run Python GC and release cached CUDA memory (best effort)."""
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            try:
                torch.cuda.ipc_collect()
            except Exception:
                # Best effort: IPC collection is not available in every setup.
                pass

    def get_model_status(self):
        """Report which models are loaded, the device, and (on CUDA) memory usage in GB."""
        status = {
            "caption_model": self.caption_model is not None,
            "clip_model": self.clip_model is not None,
            "sd_pipeline": self.sd_pipeline is not None,
            "controlnet_pipeline": self.controlnet_pipeline is not None,
            "device": self.device,
        }
        if torch.cuda.is_available():
            status["gpu_memory"] = {
                "allocated": f"{torch.cuda.memory_allocated() / 1024**3:.2f}GB",
                # "cached" reports memory reserved by the CUDA caching allocator.
                "cached": f"{torch.cuda.memory_reserved() / 1024**3:.2f}GB",
            }
        return status