| import torch |
| from PIL import Image |
| import numpy as np |
| from transformers import BlipProcessor, BlipForConditionalGeneration, CLIPProcessor, CLIPModel |
| from diffusers import StableDiffusionPipeline, ControlNetModel, StableDiffusionControlNetPipeline, EulerAncestralDiscreteScheduler |
| import os |
| import logging |
| import time |
| import random |
|
|
| logging.basicConfig(level=logging.INFO) |
| logger = logging.getLogger(__name__) |
|
|
class ModelManager:
    """Loads and serves the vision/generation models used by this service.

    Manages four model stacks:
      * BLIP  -- image captioning (``generate_caption``)
      * CLIP  -- zero-shot clothing-style scoring (``analyze_style``)
      * Stable Diffusion       -- text-to-image design generation (``generate_image``)
      * ControlNet (openpose)  -- pose-conditioned try-on generation
        (``generate_controlnet_image``)

    All weights are downloaded into ``cache_dir`` and moved to CUDA when
    available (fp16 on GPU, fp32 on CPU). Loading is eager at construction
    time; any handle that failed to load stays ``None`` and is retried
    lazily by the corresponding ``generate_*`` method.
    """

    def __init__(self, cache_dir="/tmp/models"):
        """Initialize device/dtype settings and eagerly load every model.

        Args:
            cache_dir: Hugging Face download/cache directory. Defaults to
                ``"/tmp/models"`` (the previously hard-coded location), so
                existing callers are unaffected.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # fp16 halves VRAM usage on GPU; most CPU kernels require fp32.
        self.torch_dtype = torch.float16 if self.device == "cuda" else torch.float32
        self.cache_dir = cache_dir
        logger.info(f"使用设备: {self.device}")

        # Hugging Face hub identifiers for each component.
        self.model_config = {
            "caption_model": "Salesforce/blip-image-captioning-large",
            "clip_model": "openai/clip-vit-large-patch14",
            "sd_model": "runwayml/stable-diffusion-v1-5",
            "controlnet_model": "lllyasviel/control_v11p_sd15_openpose"
        }

        # Model/processor handles. ``None`` means "not loaded (yet)": the
        # generate_* methods retry the load lazily and fail explicitly if
        # it is still unavailable.
        self.caption_processor = None
        self.caption_model = None
        self.clip_processor = None
        self.clip_model = None
        self.sd_pipeline = None
        self.controlnet = None
        self.controlnet_pipeline = None

        self.load_all_models()

    def load_all_models(self):
        """Load every model stack; individual failures are logged, not raised."""
        self.load_caption_model()
        self.load_clip_model()
        self.load_sd_pipeline()
        self.load_controlnet_pipeline()

    def load_caption_model(self):
        """Load the BLIP captioning processor + model onto ``self.device``.

        On failure the handles stay ``None`` and the error is logged.
        """
        try:
            logger.info("加载 BLIP 图像描述模型...")
            self.caption_processor = BlipProcessor.from_pretrained(
                self.model_config["caption_model"],
                cache_dir=self.cache_dir
            )
            self.caption_model = BlipForConditionalGeneration.from_pretrained(
                self.model_config["caption_model"],
                cache_dir=self.cache_dir,
                torch_dtype=self.torch_dtype
            ).to(self.device)
            logger.info("BLIP 模型加载完成")
        except Exception as e:
            logger.error(f"BLIP 模型加载失败: {e}")

    def load_clip_model(self):
        """Load the CLIP processor + model onto ``self.device``.

        On failure the handles stay ``None`` and the error is logged.
        """
        try:
            logger.info("加载 CLIP 模型...")
            self.clip_processor = CLIPProcessor.from_pretrained(
                self.model_config["clip_model"],
                cache_dir=self.cache_dir
            )
            self.clip_model = CLIPModel.from_pretrained(
                self.model_config["clip_model"],
                cache_dir=self.cache_dir,
                torch_dtype=self.torch_dtype
            ).to(self.device)
            logger.info("CLIP 模型加载完成")
        except Exception as e:
            logger.error(f"CLIP 模型加载失败: {e}")

    def load_sd_pipeline(self):
        """Load the Stable Diffusion text-to-image pipeline.

        NOTE(review): ``safety_checker=None`` disables the NSFW filter —
        confirm this is intentional for the deployment context.
        """
        try:
            logger.info("加载 Stable Diffusion Pipeline...")
            self.sd_pipeline = StableDiffusionPipeline.from_pretrained(
                self.model_config["sd_model"],
                torch_dtype=self.torch_dtype,
                cache_dir=self.cache_dir,
                safety_checker=None,
                use_safetensors=True
            )
            self.sd_pipeline = self.sd_pipeline.to(self.device)
            # Euler-ancestral tends to give good results in few steps.
            self.sd_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
                self.sd_pipeline.scheduler.config
            )
            logger.info("Stable Diffusion Pipeline 加载完成")
        except Exception as e:
            logger.error(f"Stable Diffusion Pipeline 加载失败: {e}")

    def load_controlnet_pipeline(self):
        """Load the openpose ControlNet and its SD pipeline.

        NOTE(review): ``safety_checker=None`` disables the NSFW filter —
        confirm this is intentional for the deployment context.
        """
        try:
            logger.info("加载 ControlNet 模型和 Pipeline...")
            self.controlnet = ControlNetModel.from_pretrained(
                self.model_config["controlnet_model"],
                cache_dir=self.cache_dir,
                torch_dtype=self.torch_dtype
            ).to(self.device)

            self.controlnet_pipeline = StableDiffusionControlNetPipeline.from_pretrained(
                self.model_config["sd_model"],
                controlnet=self.controlnet,
                cache_dir=self.cache_dir,
                torch_dtype=self.torch_dtype,
                safety_checker=None
            ).to(self.device)

            self.controlnet_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
                self.controlnet_pipeline.scheduler.config
            )
            logger.info("ControlNet Pipeline 加载完成")
        except Exception as e:
            logger.error(f"ControlNet Pipeline 加载失败: {e}")

    def generate_caption(self, image):
        """Generate an image caption with BLIP.

        Args:
            image: A PIL image (or anything ``BlipProcessor`` accepts).

        Returns:
            The decoded caption string.

        Raises:
            RuntimeError: If the BLIP model is unavailable even after a
                lazy reload attempt (previously this surfaced as an opaque
                ``AttributeError`` on ``None``).
        """
        if self.caption_model is None or self.caption_processor is None:
            self.load_caption_model()
        if self.caption_model is None or self.caption_processor is None:
            raise RuntimeError("BLIP caption model is not loaded")

        inputs = self.caption_processor(images=image, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.caption_model.generate(**inputs, max_length=50)
        caption = self.caption_processor.decode(outputs[0], skip_special_tokens=True)
        return caption

    def analyze_style(self, image):
        """Score the image against a fixed set of clothing styles with CLIP.

        Args:
            image: A PIL image (or anything ``CLIPProcessor`` accepts).

        Returns:
            Dict mapping each style label to its softmax probability.

        Raises:
            RuntimeError: If the CLIP model is unavailable even after a
                lazy reload attempt.
        """
        if self.clip_model is None or self.clip_processor is None:
            self.load_clip_model()
        if self.clip_model is None or self.clip_processor is None:
            raise RuntimeError("CLIP model is not loaded")

        styles = ["商务正装", "休闲风", "运动风", "时尚潮流", "复古风", "街头风", "优雅风"]

        inputs = self.clip_processor(
            text=styles,
            images=image,
            return_tensors="pt",
            padding=True
        ).to(self.device)

        with torch.no_grad():
            outputs = self.clip_model(**inputs)
            logits_per_image = outputs.logits_per_image
            probs = logits_per_image.softmax(dim=1).cpu().numpy()[0]

        style_scores = {style: float(prob) for style, prob in zip(styles, probs)}
        return style_scores

    def generate_image(self, prompt, negative_prompt=None, num_inference_steps=25, guidance_scale=7.5, width=512, height=512):
        """Generate a design image with Stable Diffusion.

        Falls back to a random placeholder image when the pipeline could
        not be loaded, so callers always receive a PIL image.

        Returns:
            The first generated PIL image, or a placeholder on failure.
        """
        if self.sd_pipeline is None:
            self.load_sd_pipeline()
        if self.sd_pipeline is None:
            logger.error("无法生成图像:Stable Diffusion 模型未加载")
            return self.create_placeholder_image(width, height)

        result = self.sd_pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            height=height,
            width=width
        )
        return result.images[0]

    def generate_controlnet_image(self, image, prompt, reference_image=None, negative_prompt=None, num_inference_steps=30, guidance_scale=8.0):
        """Generate a pose-conditioned try-on image with ControlNet.

        Args:
            image: The conditioning image (expected to be an openpose map
                for this ControlNet — TODO confirm with callers).
            reference_image: If given, only augments the prompt text; the
                pixels themselves are not fed to the pipeline.

        Returns:
            The first generated PIL image, or a 512x768 placeholder when
            the pipeline could not be loaded.
        """
        if self.controlnet_pipeline is None:
            self.load_controlnet_pipeline()
        if self.controlnet_pipeline is None:
            logger.error("无法生成3D试穿:ControlNet 模型未加载")
            return self.create_placeholder_image(512, 768)

        if reference_image is not None:
            prompt = f"{prompt}, based on reference design"

        result = self.controlnet_pipeline(
            prompt=prompt,
            image=image,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        )
        return result.images[0]

    def create_placeholder_image(self, width, height):
        """Return a solid-color RGB placeholder of the requested size."""
        color = (random.randint(120, 200), random.randint(120, 200), random.randint(120, 200))
        return Image.new('RGB', (width, height), color=color)

    def cleanup(self):
        """Release model memory and empty the CUDA cache.

        Assigns ``None`` to every handle instead of ``del`` — deleting the
        attributes (as the previous version did) broke the ``is None``
        lazy-reload checks in the generate_* methods with ``AttributeError``
        and made a second ``cleanup()`` call fail. A ``gc.collect()`` pass
        drops the now-unreferenced weights before the cache is emptied.
        This method is idempotent.
        """
        import gc

        logger.info("释放模型占用显存和缓存...")
        try:
            self.caption_model = None
            self.caption_processor = None
            self.clip_model = None
            self.clip_processor = None
            self.sd_pipeline = None
            self.controlnet = None
            self.controlnet_pipeline = None

            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            logger.info("显存和缓存清理完成")
        except Exception as e:
            logger.error(f"清理显存失败: {e}")
|
|